From fb32c388985ce65c1083cb435cf1f7479fecbaac Mon Sep 17 00:00:00 2001 From: MechCoder Date: Wed, 24 Jun 2015 14:58:43 -0700 Subject: [PATCH 001/122] [SPARK-7633] [MLLIB] [PYSPARK] Python bindings for StreamingLogisticRegressionwithSGD Add Python bindings to StreamingLogisticRegressionwithSGD. No Java wrappers are needed as models are updated directly using train. Author: MechCoder Closes #6849 from MechCoder/spark-3258 and squashes the following commits: b4376a5 [MechCoder] minor d7e5fc1 [MechCoder] Refactor into StreamingLinearAlgorithm Better docs 9c09d4e [MechCoder] [SPARK-7633] Python bindings for StreamingLogisticRegressionwithSGD --- python/pyspark/mllib/classification.py | 96 +++++++++++++++++- python/pyspark/mllib/tests.py | 135 ++++++++++++++++++++++++- 2 files changed, 229 insertions(+), 2 deletions(-) diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py index 758accf4b41eb..2698f10d06883 100644 --- a/python/pyspark/mllib/classification.py +++ b/python/pyspark/mllib/classification.py @@ -21,6 +21,7 @@ from numpy import array from pyspark import RDD +from pyspark.streaming import DStream from pyspark.mllib.common import callMLlibFunc, _py2java, _java2py from pyspark.mllib.linalg import DenseVector, SparseVector, _convert_to_vector from pyspark.mllib.regression import LabeledPoint, LinearModel, _regression_train_wrapper @@ -28,7 +29,8 @@ __all__ = ['LogisticRegressionModel', 'LogisticRegressionWithSGD', 'LogisticRegressionWithLBFGS', - 'SVMModel', 'SVMWithSGD', 'NaiveBayesModel', 'NaiveBayes'] + 'SVMModel', 'SVMWithSGD', 'NaiveBayesModel', 'NaiveBayes', + 'StreamingLogisticRegressionWithSGD'] class LinearClassificationModel(LinearModel): @@ -583,6 +585,98 @@ def train(cls, data, lambda_=1.0): return NaiveBayesModel(labels.toArray(), pi.toArray(), numpy.array(theta)) +class StreamingLinearAlgorithm(object): + """ + Base class that has to be inherited by any StreamingLinearAlgorithm. + + Prevents reimplementation of methods predictOn and predictOnValues. + """ + def __init__(self, model): + self._model = model + + def latestModel(self): + """ + Returns the latest model. + """ + return self._model + + def _validate(self, dstream): + if not isinstance(dstream, DStream): + raise TypeError( + "dstream should be a DStream object, got %s" % type(dstream)) + if not self._model: + raise ValueError( + "Model must be intialized using setInitialWeights") + + def predictOn(self, dstream): + """ + Make predictions on a dstream. + + :return: Transformed dstream object. + """ + self._validate(dstream) + return dstream.map(lambda x: self._model.predict(x)) + + def predictOnValues(self, dstream): + """ + Make predictions on a keyed dstream. + + :return: Transformed dstream object. + """ + self._validate(dstream) + return dstream.mapValues(lambda x: self._model.predict(x)) + + +@inherit_doc +class StreamingLogisticRegressionWithSGD(StreamingLinearAlgorithm): + """ + Run LogisticRegression with SGD on a stream of data. + + The weights obtained at the end of training a stream are used as initial + weights for the next stream. + + :param stepSize: Step size for each iteration of gradient descent. + :param numIterations: Number of iterations run for each batch of data. + :param miniBatchFraction: Fraction of data on which SGD is run for each + iteration. + :param regParam: L2 Regularization parameter. 
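As a quick orientation before the constructor and training methods that follow, here is a minimal sketch of how the Python API added by this patch could be wired together. It is illustrative only: the application name, batch interval, toy queued batches, and the use of queueStream (borrowed from the unit tests later in this patch) are assumptions, and a real job would train on a live DStream source instead.

from pyspark import SparkContext
from pyspark.streaming import StreamingContext
from pyspark.mllib.classification import StreamingLogisticRegressionWithSGD
from pyspark.mllib.regression import LabeledPoint

sc = SparkContext(appName="streaming-lr-sketch")   # assumed local/toy setup
ssc = StreamingContext(sc, batchDuration=1)

# Two tiny micro-batches of labeled points, queued up the same way the unit
# tests later in this patch do it.
batches = [sc.parallelize([LabeledPoint(1.0, [1.0]), LabeledPoint(0.0, [-1.0])])
           for _ in range(2)]
stream = ssc.queueStream(batches)

model = StreamingLogisticRegressionWithSGD(stepSize=0.1, numIterations=50)
model.setInitialWeights([0.0])           # must be called before trainOn/predictOn
model.trainOn(stream)                    # updates model.latestModel() once per batch
model.predictOnValues(
    stream.map(lambda lp: (lp.label, lp.features))).pprint()

ssc.start()
ssc.awaitTerminationOrTimeout(10)
ssc.stop()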
+ """ + def __init__(self, stepSize=0.1, numIterations=50, miniBatchFraction=1.0, regParam=0.01): + self.stepSize = stepSize + self.numIterations = numIterations + self.regParam = regParam + self.miniBatchFraction = miniBatchFraction + self._model = None + super(StreamingLogisticRegressionWithSGD, self).__init__( + model=self._model) + + def setInitialWeights(self, initialWeights): + """ + Set the initial value of weights. + + This must be set before running trainOn and predictOn. + """ + initialWeights = _convert_to_vector(initialWeights) + + # LogisticRegressionWithSGD does only binary classification. + self._model = LogisticRegressionModel( + initialWeights, 0, initialWeights.size, 2) + return self + + def trainOn(self, dstream): + """Train the model on the incoming dstream.""" + self._validate(dstream) + + def update(rdd): + # LogisticRegressionWithSGD.train raises an error for an empty RDD. + if not rdd.isEmpty(): + self._model = LogisticRegressionWithSGD.train( + rdd, self.numIterations, self.stepSize, + self.miniBatchFraction, self._model.weights) + + dstream.foreachRDD(update) + + def _test(): import doctest from pyspark import SparkContext diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 509faa11df170..cd80c3e07a4f7 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -26,7 +26,8 @@ from time import time, sleep from shutil import rmtree -from numpy import array, array_equal, zeros, inf, all, random +from numpy import ( + array, array_equal, zeros, inf, random, exp, dot, all, mean) from numpy import sum as array_sum from py4j.protocol import Py4JJavaError @@ -45,6 +46,7 @@ from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, VectorUDT, _convert_to_vector,\ DenseMatrix, SparseMatrix, Vectors, Matrices, MatrixUDT from pyspark.mllib.regression import LabeledPoint +from pyspark.mllib.classification import StreamingLogisticRegressionWithSGD from pyspark.mllib.random import RandomRDDs from pyspark.mllib.stat import Statistics from pyspark.mllib.feature import Word2Vec @@ -1037,6 +1039,137 @@ def test_dim(self): self.assertEqual(len(point.features), 2) +class StreamingLogisticRegressionWithSGDTests(MLLibStreamingTestCase): + + @staticmethod + def generateLogisticInput(offset, scale, nPoints, seed): + """ + Generate 1 / (1 + exp(-x * scale + offset)) + + where, + x is randomnly distributed and the threshold + and labels for each sample in x is obtained from a random uniform + distribution. + """ + rng = random.RandomState(seed) + x = rng.randn(nPoints) + sigmoid = 1. / (1 + exp(-(dot(x, scale) + offset))) + y_p = rng.rand(nPoints) + cut_off = y_p <= sigmoid + y_p[cut_off] = 1.0 + y_p[~cut_off] = 0.0 + return [ + LabeledPoint(y_p[i], Vectors.dense([x[i]])) + for i in range(nPoints)] + + def test_parameter_accuracy(self): + """ + Test that the final value of weights is close to the desired value. + """ + input_batches = [ + self.sc.parallelize(self.generateLogisticInput(0, 1.5, 100, 42 + i)) + for i in range(20)] + input_stream = self.ssc.queueStream(input_batches) + + slr = StreamingLogisticRegressionWithSGD( + stepSize=0.2, numIterations=25) + slr.setInitialWeights([0.0]) + slr.trainOn(input_stream) + + t = time() + self.ssc.start() + self._ssc_wait(t, 20.0, 0.01) + rel = (1.5 - slr.latestModel().weights.array[0]) / 1.5 + self.assertAlmostEqual(rel, 0.1, 1) + + def test_convergence(self): + """ + Test that weights converge to the required value on toy data. 
+ """ + input_batches = [ + self.sc.parallelize(self.generateLogisticInput(0, 1.5, 100, 42 + i)) + for i in range(20)] + input_stream = self.ssc.queueStream(input_batches) + models = [] + + slr = StreamingLogisticRegressionWithSGD( + stepSize=0.2, numIterations=25) + slr.setInitialWeights([0.0]) + slr.trainOn(input_stream) + input_stream.foreachRDD( + lambda x: models.append(slr.latestModel().weights[0])) + + t = time() + self.ssc.start() + self._ssc_wait(t, 15.0, 0.01) + t_models = array(models) + diff = t_models[1:] - t_models[:-1] + + # Test that weights improve with a small tolerance, + self.assertTrue(all(diff >= -0.1)) + self.assertTrue(array_sum(diff > 0) > 1) + + @staticmethod + def calculate_accuracy_error(true, predicted): + return sum(abs(array(true) - array(predicted))) / len(true) + + def test_predictions(self): + """Test predicted values on a toy model.""" + input_batches = [] + for i in range(20): + batch = self.sc.parallelize( + self.generateLogisticInput(0, 1.5, 100, 42 + i)) + input_batches.append(batch.map(lambda x: (x.label, x.features))) + input_stream = self.ssc.queueStream(input_batches) + + slr = StreamingLogisticRegressionWithSGD( + stepSize=0.2, numIterations=25) + slr.setInitialWeights([1.5]) + predict_stream = slr.predictOnValues(input_stream) + true_predicted = [] + predict_stream.foreachRDD(lambda x: true_predicted.append(x.collect())) + t = time() + self.ssc.start() + self._ssc_wait(t, 5.0, 0.01) + + # Test that the accuracy error is no more than 0.4 on each batch. + for batch in true_predicted: + true, predicted = zip(*batch) + self.assertTrue( + self.calculate_accuracy_error(true, predicted) < 0.4) + + def test_training_and_prediction(self): + """Test that the model improves on toy data with no. of batches""" + input_batches = [ + self.sc.parallelize(self.generateLogisticInput(0, 1.5, 100, 42 + i)) + for i in range(20)] + predict_batches = [ + b.map(lambda lp: (lp.label, lp.features)) for b in input_batches] + + slr = StreamingLogisticRegressionWithSGD( + stepSize=0.01, numIterations=25) + slr.setInitialWeights([-0.1]) + errors = [] + + def collect_errors(rdd): + true, predicted = zip(*rdd.collect()) + errors.append(self.calculate_accuracy_error(true, predicted)) + + true_predicted = [] + input_stream = self.ssc.queueStream(input_batches) + predict_stream = self.ssc.queueStream(predict_batches) + slr.trainOn(input_stream) + ps = slr.predictOnValues(predict_stream) + ps.foreachRDD(lambda x: collect_errors(x)) + + t = time() + self.ssc.start() + self._ssc_wait(t, 20.0, 0.01) + + # Test that the improvement in error is atleast 0.3 + self.assertTrue(errors[1] - errors[-1] > 0.3) + + if __name__ == "__main__": if not _have_scipy: print("NOTE: Skipping SciPy tests as it does not seem to be installed") From 8ab50765cd793169091d983b50d87a391f6ac1f4 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Wed, 24 Jun 2015 15:03:43 -0700 Subject: [PATCH 002/122] [SPARK-6777] [SQL] Implements backwards compatibility rules in CatalystSchemaConverter This PR introduces `CatalystSchemaConverter` for converting Parquet schema to Spark SQL schema and vice versa. Original conversion code in `ParquetTypesConverter` is removed. Benefits of the new version are: 1. When converting Spark SQL schemas, it generates standard Parquet schemas conforming to [the most updated Parquet format spec] [1]. 
Converting to old style Parquet schemas is also supported via feature flag `spark.sql.parquet.followParquetFormatSpec` (which is set to `false` for now, and should be set to `true` after both read and write paths are fixed). Note that although this version of Parquet format spec hasn't been officially release yet, Parquet MR 1.7.0 already sticks to it. So it should be safe to follow. 1. It implements backwards-compatibility rules described in the most updated Parquet format spec. Thus can recognize more schema patterns generated by other/legacy systems/tools. 1. Code organization follows convention used in [parquet-mr] [2], which is easier to follow. (Structure of `CatalystSchemaConverter` is similar to `AvroSchemaConverter`). To fully implement backwards-compatibility rules in both read and write path, we also need to update `CatalystRowConverter` (which is responsible for converting Parquet records to `Row`s), `RowReadSupport`, and `RowWriteSupport`. These would be done in follow-up PRs. TODO - [x] More schema conversion test cases for legacy schema patterns. [1]: https://github.com/apache/parquet-format/blob/ea095226597fdbecd60c2419d96b54b2fdb4ae6c/LogicalTypes.md [2]: https://github.com/apache/parquet-mr/ Author: Cheng Lian Closes #6617 from liancheng/spark-6777 and squashes the following commits: 2a2062d [Cheng Lian] Don't convert decimals without precision information b60979b [Cheng Lian] Adds a constructor which accepts a Configuration, and fixes default value of assumeBinaryIsString 743730f [Cheng Lian] Decimal scale shouldn't be larger than precision a104a9e [Cheng Lian] Fixes Scala style issue 1f71d8d [Cheng Lian] Adds feature flag to allow falling back to old style Parquet schema conversion ba84f4b [Cheng Lian] Fixes MapType schema conversion bug 13cb8d5 [Cheng Lian] Fixes MiMa failure 81de5b0 [Cheng Lian] Fixes UDT, workaround read path, and add tests 28ef95b [Cheng Lian] More AnalysisExceptions b10c322 [Cheng Lian] Replaces require() with analysisRequire() which throws AnalysisException cceaf3f [Cheng Lian] Implements backwards compatibility rules in CatalystSchemaConverter --- project/MimaExcludes.scala | 7 +- .../apache/spark/sql/types/DecimalType.scala | 9 +- .../scala/org/apache/spark/sql/SQLConf.scala | 14 + .../sql/parquet/CatalystSchemaConverter.scala | 565 ++++++++++++++ .../sql/parquet/ParquetTableSupport.scala | 6 +- .../spark/sql/parquet/ParquetTypes.scala | 374 +-------- .../spark/sql/parquet/ParquetIOSuite.scala | 6 +- .../sql/parquet/ParquetSchemaSuite.scala | 736 ++++++++++++++++-- .../sql/hive/execution/SQLQuerySuite.scala | 2 +- 9 files changed, 1297 insertions(+), 422 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index f678c69a6dfa9..6f86a505b3ae4 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -69,7 +69,12 @@ object MimaExcludes { ProblemFilters.exclude[MissingClassProblem]( "org.apache.spark.sql.parquet.CatalystTimestampConverter"), ProblemFilters.exclude[MissingClassProblem]( - "org.apache.spark.sql.parquet.CatalystTimestampConverter$") + "org.apache.spark.sql.parquet.CatalystTimestampConverter$"), + // SPARK-6777 Implements backwards compatibility rules in CatalystSchemaConverter + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.parquet.ParquetTypeInfo"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.sql.parquet.ParquetTypeInfo$") ) case v if 
v.startsWith("1.4") => Seq( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala index 407dc27326c2e..18cdfa7238f39 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala @@ -20,13 +20,18 @@ package org.apache.spark.sql.types import scala.reflect.runtime.universe.typeTag import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.ScalaReflectionLock import org.apache.spark.sql.catalyst.expressions.Expression /** Precision parameters for a Decimal */ -case class PrecisionInfo(precision: Int, scale: Int) - +case class PrecisionInfo(precision: Int, scale: Int) { + if (scale > precision) { + throw new AnalysisException( + s"Decimal scale ($scale) cannot be greater than precision ($precision).") + } +} /** * :: DeveloperApi :: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 265352647fa9f..9a10a23937fbb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -264,6 +264,14 @@ private[spark] object SQLConf { defaultValue = Some(true), doc = "") + val PARQUET_FOLLOW_PARQUET_FORMAT_SPEC = booleanConf( + key = "spark.sql.parquet.followParquetFormatSpec", + defaultValue = Some(false), + doc = "Wether to stick to Parquet format specification when converting Parquet schema to " + + "Spark SQL schema and vice versa. Sticks to the specification if set to true; falls back " + + "to compatible mode if set to false.", + isPublic = false) + val PARQUET_OUTPUT_COMMITTER_CLASS = stringConf( key = "spark.sql.parquet.output.committer.class", defaultValue = Some(classOf[ParquetOutputCommitter].getName), @@ -498,6 +506,12 @@ private[sql] class SQLConf extends Serializable with CatalystConf { */ private[spark] def isParquetINT96AsTimestamp: Boolean = getConf(PARQUET_INT96_AS_TIMESTAMP) + /** + * When set to true, sticks to Parquet format spec when converting Parquet schema to Spark SQL + * schema and vice versa. Otherwise, falls back to compatible mode. + */ + private[spark] def followParquetFormatSpec: Boolean = getConf(PARQUET_FOLLOW_PARQUET_FORMAT_SPEC) + /** * When set to true, partition pruning for in-memory columnar tables is enabled. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala new file mode 100644 index 0000000000000..4fd3e93b70311 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala @@ -0,0 +1,565 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.parquet + +import scala.collection.JavaConversions._ + +import org.apache.hadoop.conf.Configuration +import org.apache.parquet.schema.OriginalType._ +import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName._ +import org.apache.parquet.schema.Type.Repetition._ +import org.apache.parquet.schema._ + +import org.apache.spark.sql.types._ +import org.apache.spark.sql.{AnalysisException, SQLConf} + +/** + * This converter class is used to convert Parquet [[MessageType]] to Spark SQL [[StructType]] and + * vice versa. + * + * Parquet format backwards-compatibility rules are respected when converting Parquet + * [[MessageType]] schemas. + * + * @see https://github.com/apache/parquet-format/blob/master/LogicalTypes.md + * + * @constructor + * @param assumeBinaryIsString Whether unannotated BINARY fields should be assumed to be Spark SQL + * [[StringType]] fields when converting Parquet a [[MessageType]] to Spark SQL + * [[StructType]]. + * @param assumeInt96IsTimestamp Whether unannotated INT96 fields should be assumed to be Spark SQL + * [[TimestampType]] fields when converting Parquet a [[MessageType]] to Spark SQL + * [[StructType]]. Note that Spark SQL [[TimestampType]] is similar to Hive timestamp, which + * has optional nanosecond precision, but different from `TIME_MILLS` and `TIMESTAMP_MILLIS` + * described in Parquet format spec. + * @param followParquetFormatSpec Whether to generate standard DECIMAL, LIST, and MAP structure when + * converting Spark SQL [[StructType]] to Parquet [[MessageType]]. For Spark 1.4.x and + * prior versions, Spark SQL only supports decimals with a max precision of 18 digits, and + * uses non-standard LIST and MAP structure. Note that the current Parquet format spec is + * backwards-compatible with these settings. If this argument is set to `false`, we fallback + * to old style non-standard behaviors. + */ +private[parquet] class CatalystSchemaConverter( + private val assumeBinaryIsString: Boolean, + private val assumeInt96IsTimestamp: Boolean, + private val followParquetFormatSpec: Boolean) { + + // Only used when constructing converter for converting Spark SQL schema to Parquet schema, in + // which case `assumeInt96IsTimestamp` and `assumeBinaryIsString` are irrelevant. 
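The constructors that follow pull these settings from SQLConf or a Hadoop Configuration. For orientation, the new switch is also reachable from PySpark as an ordinary SQL configuration key; the sketch below is illustrative only — the paths and toy DataFrame are made up, spark.sql.parquet.followParquetFormatSpec is marked non-public in this patch, and the commit message notes that the read and write paths are still being migrated, so flipping the flag may not change the written files yet.

from pyspark import SparkContext
from pyspark.sql import SQLContext, Row

sc = SparkContext(appName="parquet-format-spec-sketch")
sqlContext = SQLContext(sc)

df = sqlContext.createDataFrame([Row(id=i, tags=[str(i)]) for i in range(3)])

# Default (followParquetFormatSpec = false): legacy-compatible schema layout,
# e.g. old-style LIST/MAP groups and FIXED_LEN_BYTE_ARRAY decimals.
df.write.parquet("/tmp/parquet-legacy-layout")        # path is illustrative

# Opt in to schemas that follow the Parquet format spec (standard LIST/MAP/DECIMAL).
sqlContext.setConf("spark.sql.parquet.followParquetFormatSpec", "true")
df.write.parquet("/tmp/parquet-standard-layout")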
+ def this() = this( + assumeBinaryIsString = SQLConf.PARQUET_BINARY_AS_STRING.defaultValue.get, + assumeInt96IsTimestamp = SQLConf.PARQUET_INT96_AS_TIMESTAMP.defaultValue.get, + followParquetFormatSpec = SQLConf.PARQUET_FOLLOW_PARQUET_FORMAT_SPEC.defaultValue.get) + + def this(conf: SQLConf) = this( + assumeBinaryIsString = conf.isParquetBinaryAsString, + assumeInt96IsTimestamp = conf.isParquetINT96AsTimestamp, + followParquetFormatSpec = conf.followParquetFormatSpec) + + def this(conf: Configuration) = this( + assumeBinaryIsString = + conf.getBoolean( + SQLConf.PARQUET_BINARY_AS_STRING.key, + SQLConf.PARQUET_BINARY_AS_STRING.defaultValue.get), + assumeInt96IsTimestamp = + conf.getBoolean( + SQLConf.PARQUET_INT96_AS_TIMESTAMP.key, + SQLConf.PARQUET_INT96_AS_TIMESTAMP.defaultValue.get), + followParquetFormatSpec = + conf.getBoolean( + SQLConf.PARQUET_FOLLOW_PARQUET_FORMAT_SPEC.key, + SQLConf.PARQUET_FOLLOW_PARQUET_FORMAT_SPEC.defaultValue.get)) + + /** + * Converts Parquet [[MessageType]] `parquetSchema` to a Spark SQL [[StructType]]. + */ + def convert(parquetSchema: MessageType): StructType = convert(parquetSchema.asGroupType()) + + private def convert(parquetSchema: GroupType): StructType = { + val fields = parquetSchema.getFields.map { field => + field.getRepetition match { + case OPTIONAL => + StructField(field.getName, convertField(field), nullable = true) + + case REQUIRED => + StructField(field.getName, convertField(field), nullable = false) + + case REPEATED => + throw new AnalysisException( + s"REPEATED not supported outside LIST or MAP. Type: $field") + } + } + + StructType(fields) + } + + /** + * Converts a Parquet [[Type]] to a Spark SQL [[DataType]]. + */ + def convertField(parquetType: Type): DataType = parquetType match { + case t: PrimitiveType => convertPrimitiveField(t) + case t: GroupType => convertGroupField(t.asGroupType()) + } + + private def convertPrimitiveField(field: PrimitiveType): DataType = { + val typeName = field.getPrimitiveTypeName + val originalType = field.getOriginalType + + def typeString = + if (originalType == null) s"$typeName" else s"$typeName ($originalType)" + + def typeNotImplemented() = + throw new AnalysisException(s"Parquet type not yet supported: $typeString") + + def illegalType() = + throw new AnalysisException(s"Illegal Parquet type: $typeString") + + // When maxPrecision = -1, we skip precision range check, and always respect the precision + // specified in field.getDecimalMetadata. This is useful when interpreting decimal types stored + // as binaries with variable lengths. 
+ def makeDecimalType(maxPrecision: Int = -1): DecimalType = { + val precision = field.getDecimalMetadata.getPrecision + val scale = field.getDecimalMetadata.getScale + + CatalystSchemaConverter.analysisRequire( + maxPrecision == -1 || 1 <= precision && precision <= maxPrecision, + s"Invalid decimal precision: $typeName cannot store $precision digits (max $maxPrecision)") + + DecimalType(precision, scale) + } + + field.getPrimitiveTypeName match { + case BOOLEAN => BooleanType + + case FLOAT => FloatType + + case DOUBLE => DoubleType + + case INT32 => + field.getOriginalType match { + case INT_8 => ByteType + case INT_16 => ShortType + case INT_32 | null => IntegerType + case DATE => DateType + case DECIMAL => makeDecimalType(maxPrecisionForBytes(4)) + case TIME_MILLIS => typeNotImplemented() + case _ => illegalType() + } + + case INT64 => + field.getOriginalType match { + case INT_64 | null => LongType + case DECIMAL => makeDecimalType(maxPrecisionForBytes(8)) + case TIMESTAMP_MILLIS => typeNotImplemented() + case _ => illegalType() + } + + case INT96 => + CatalystSchemaConverter.analysisRequire( + assumeInt96IsTimestamp, + "INT96 is not supported unless it's interpreted as timestamp. " + + s"Please try to set ${SQLConf.PARQUET_INT96_AS_TIMESTAMP.key} to true.") + TimestampType + + case BINARY => + field.getOriginalType match { + case UTF8 => StringType + case null if assumeBinaryIsString => StringType + case null => BinaryType + case DECIMAL => makeDecimalType() + case _ => illegalType() + } + + case FIXED_LEN_BYTE_ARRAY => + field.getOriginalType match { + case DECIMAL => makeDecimalType(maxPrecisionForBytes(field.getTypeLength)) + case INTERVAL => typeNotImplemented() + case _ => illegalType() + } + + case _ => illegalType() + } + } + + private def convertGroupField(field: GroupType): DataType = { + Option(field.getOriginalType).fold(convert(field): DataType) { + // A Parquet list is represented as a 3-level structure: + // + // group (LIST) { + // repeated group list { + // element; + // } + // } + // + // However, according to the most recent Parquet format spec (not released yet up until + // writing), some 2-level structures are also recognized for backwards-compatibility. Thus, + // we need to check whether the 2nd level or the 3rd level refers to list element type. 
+ // + // See: https://github.com/apache/parquet-format/blob/master/LogicalTypes.md#lists + case LIST => + CatalystSchemaConverter.analysisRequire( + field.getFieldCount == 1, s"Invalid list type $field") + + val repeatedType = field.getType(0) + CatalystSchemaConverter.analysisRequire( + repeatedType.isRepetition(REPEATED), s"Invalid list type $field") + + if (isElementType(repeatedType, field.getName)) { + ArrayType(convertField(repeatedType), containsNull = false) + } else { + val elementType = repeatedType.asGroupType().getType(0) + val optional = elementType.isRepetition(OPTIONAL) + ArrayType(convertField(elementType), containsNull = optional) + } + + // scalastyle:off + // `MAP_KEY_VALUE` is for backwards-compatibility + // See: https://github.com/apache/parquet-format/blob/master/LogicalTypes.md#backward-compatibility-rules-1 + // scalastyle:on + case MAP | MAP_KEY_VALUE => + CatalystSchemaConverter.analysisRequire( + field.getFieldCount == 1 && !field.getType(0).isPrimitive, + s"Invalid map type: $field") + + val keyValueType = field.getType(0).asGroupType() + CatalystSchemaConverter.analysisRequire( + keyValueType.isRepetition(REPEATED) && keyValueType.getFieldCount == 2, + s"Invalid map type: $field") + + val keyType = keyValueType.getType(0) + CatalystSchemaConverter.analysisRequire( + keyType.isPrimitive, + s"Map key type is expected to be a primitive type, but found: $keyType") + + val valueType = keyValueType.getType(1) + val valueOptional = valueType.isRepetition(OPTIONAL) + MapType( + convertField(keyType), + convertField(valueType), + valueContainsNull = valueOptional) + + case _ => + throw new AnalysisException(s"Unrecognized Parquet type: $field") + } + } + + // scalastyle:off + // Here we implement Parquet LIST backwards-compatibility rules. + // See: https://github.com/apache/parquet-format/blob/master/LogicalTypes.md#backward-compatibility-rules + // scalastyle:on + private def isElementType(repeatedType: Type, parentName: String) = { + { + // For legacy 2-level list types with primitive element type, e.g.: + // + // // List (nullable list, non-null elements) + // optional group my_list (LIST) { + // repeated int32 element; + // } + // + repeatedType.isPrimitive + } || { + // For legacy 2-level list types whose element type is a group type with 2 or more fields, + // e.g.: + // + // // List> (nullable list, non-null elements) + // optional group my_list (LIST) { + // repeated group element { + // required binary str (UTF8); + // required int32 num; + // }; + // } + // + repeatedType.asGroupType().getFieldCount > 1 + } || { + // For legacy 2-level list types generated by parquet-avro (Parquet version < 1.6.0), e.g.: + // + // // List> (nullable list, non-null elements) + // optional group my_list (LIST) { + // repeated group array { + // required binary str (UTF8); + // }; + // } + // + repeatedType.getName == "array" + } || { + // For Parquet data generated by parquet-thrift, e.g.: + // + // // List> (nullable list, non-null elements) + // optional group my_list (LIST) { + // repeated group my_list_tuple { + // required binary str (UTF8); + // }; + // } + // + repeatedType.getName == s"${parentName}_tuple" + } + } + + /** + * Converts a Spark SQL [[StructType]] to a Parquet [[MessageType]]. + */ + def convert(catalystSchema: StructType): MessageType = { + Types.buildMessage().addFields(catalystSchema.map(convertField): _*).named("root") + } + + /** + * Converts a Spark SQL [[StructField]] to a Parquet [[Type]]. 
+ */ + def convertField(field: StructField): Type = { + convertField(field, if (field.nullable) OPTIONAL else REQUIRED) + } + + private def convertField(field: StructField, repetition: Type.Repetition): Type = { + CatalystSchemaConverter.checkFieldName(field.name) + + field.dataType match { + // =================== + // Simple atomic types + // =================== + + case BooleanType => + Types.primitive(BOOLEAN, repetition).named(field.name) + + case ByteType => + Types.primitive(INT32, repetition).as(INT_8).named(field.name) + + case ShortType => + Types.primitive(INT32, repetition).as(INT_16).named(field.name) + + case IntegerType => + Types.primitive(INT32, repetition).named(field.name) + + case LongType => + Types.primitive(INT64, repetition).named(field.name) + + case FloatType => + Types.primitive(FLOAT, repetition).named(field.name) + + case DoubleType => + Types.primitive(DOUBLE, repetition).named(field.name) + + case StringType => + Types.primitive(BINARY, repetition).as(UTF8).named(field.name) + + case DateType => + Types.primitive(INT32, repetition).as(DATE).named(field.name) + + // NOTE: !! This timestamp type is not specified in Parquet format spec !! + // However, Impala and older versions of Spark SQL use INT96 to store timestamps with + // nanosecond precision (not TIME_MILLIS or TIMESTAMP_MILLIS described in the spec). + case TimestampType => + Types.primitive(INT96, repetition).named(field.name) + + case BinaryType => + Types.primitive(BINARY, repetition).named(field.name) + + // ===================================== + // Decimals (for Spark version <= 1.4.x) + // ===================================== + + // Spark 1.4.x and prior versions only support decimals with a maximum precision of 18 and + // always store decimals in fixed-length byte arrays. + case DecimalType.Fixed(precision, scale) + if precision <= maxPrecisionForBytes(8) && !followParquetFormatSpec => + Types + .primitive(FIXED_LEN_BYTE_ARRAY, repetition) + .as(DECIMAL) + .precision(precision) + .scale(scale) + .length(minBytesForPrecision(precision)) + .named(field.name) + + case dec @ DecimalType() if !followParquetFormatSpec => + throw new AnalysisException( + s"Data type $dec is not supported. " + + s"When ${SQLConf.PARQUET_FOLLOW_PARQUET_FORMAT_SPEC.key} is set to false," + + "decimal precision and scale must be specified, " + + "and precision must be less than or equal to 18.") + + // ===================================== + // Decimals (follow Parquet format spec) + // ===================================== + + // Uses INT32 for 1 <= precision <= 9 + case DecimalType.Fixed(precision, scale) + if precision <= maxPrecisionForBytes(4) && followParquetFormatSpec => + Types + .primitive(INT32, repetition) + .as(DECIMAL) + .precision(precision) + .scale(scale) + .named(field.name) + + // Uses INT64 for 1 <= precision <= 18 + case DecimalType.Fixed(precision, scale) + if precision <= maxPrecisionForBytes(8) && followParquetFormatSpec => + Types + .primitive(INT64, repetition) + .as(DECIMAL) + .precision(precision) + .scale(scale) + .named(field.name) + + // Uses FIXED_LEN_BYTE_ARRAY for all other precisions + case DecimalType.Fixed(precision, scale) if followParquetFormatSpec => + Types + .primitive(FIXED_LEN_BYTE_ARRAY, repetition) + .as(DECIMAL) + .precision(precision) + .scale(scale) + .length(minBytesForPrecision(precision)) + .named(field.name) + + case dec @ DecimalType.Unlimited if followParquetFormatSpec => + throw new AnalysisException( + s"Data type $dec is not supported. 
Decimal precision and scale must be specified.") + + // =================================================== + // ArrayType and MapType (for Spark versions <= 1.4.x) + // =================================================== + + // Spark 1.4.x and prior versions convert ArrayType with nullable elements into a 3-level + // LIST structure. This behavior mimics parquet-hive (1.6.0rc3). Note that this case is + // covered by the backwards-compatibility rules implemented in `isElementType()`. + case ArrayType(elementType, nullable @ true) if !followParquetFormatSpec => + // group (LIST) { + // optional group bag { + // repeated element; + // } + // } + ConversionPatterns.listType( + repetition, + field.name, + Types + .buildGroup(REPEATED) + .addField(convertField(StructField("element", elementType, nullable))) + .named(CatalystConverter.ARRAY_CONTAINS_NULL_BAG_SCHEMA_NAME)) + + // Spark 1.4.x and prior versions convert ArrayType with non-nullable elements into a 2-level + // LIST structure. This behavior mimics parquet-avro (1.6.0rc3). Note that this case is + // covered by the backwards-compatibility rules implemented in `isElementType()`. + case ArrayType(elementType, nullable @ false) if !followParquetFormatSpec => + // group (LIST) { + // repeated element; + // } + ConversionPatterns.listType( + repetition, + field.name, + convertField(StructField("element", elementType, nullable), REPEATED)) + + // Spark 1.4.x and prior versions convert MapType into a 3-level group annotated by + // MAP_KEY_VALUE. This is covered by `convertGroupField(field: GroupType): DataType`. + case MapType(keyType, valueType, valueContainsNull) if !followParquetFormatSpec => + // group (MAP) { + // repeated group map (MAP_KEY_VALUE) { + // required key; + // value; + // } + // } + ConversionPatterns.mapType( + repetition, + field.name, + convertField(StructField("key", keyType, nullable = false)), + convertField(StructField("value", valueType, valueContainsNull))) + + // ================================================== + // ArrayType and MapType (follow Parquet format spec) + // ================================================== + + case ArrayType(elementType, containsNull) if followParquetFormatSpec => + // group (LIST) { + // repeated group list { + // element; + // } + // } + Types + .buildGroup(repetition).as(LIST) + .addField( + Types.repeatedGroup() + .addField(convertField(StructField("element", elementType, containsNull))) + .named("list")) + .named(field.name) + + case MapType(keyType, valueType, valueContainsNull) => + // group (MAP) { + // repeated group key_value { + // required key; + // value; + // } + // } + Types + .buildGroup(repetition).as(MAP) + .addField( + Types + .repeatedGroup() + .addField(convertField(StructField("key", keyType, nullable = false))) + .addField(convertField(StructField("value", valueType, valueContainsNull))) + .named("key_value")) + .named(field.name) + + // =========== + // Other types + // =========== + + case StructType(fields) => + fields.foldLeft(Types.buildGroup(repetition)) { (builder, field) => + builder.addField(convertField(field)) + }.named(field.name) + + case udt: UserDefinedType[_] => + convertField(field.copy(dataType = udt.sqlType)) + + case _ => + throw new AnalysisException(s"Unsupported data type $field.dataType") + } + } + + // Max precision of a decimal value stored in `numBytes` bytes + private def maxPrecisionForBytes(numBytes: Int): Int = { + Math.round( // convert double to long + Math.floor(Math.log10( // number of base-10 digits + Math.pow(2, 8 * 
numBytes - 1) - 1))) // max value stored in numBytes + .asInstanceOf[Int] + } + + // Min byte counts needed to store decimals with various precisions + private val minBytesForPrecision: Array[Int] = Array.tabulate(38) { precision => + var numBytes = 1 + while (math.pow(2.0, 8 * numBytes - 1) < math.pow(10.0, precision)) { + numBytes += 1 + } + numBytes + } +} + + +private[parquet] object CatalystSchemaConverter { + def checkFieldName(name: String): Unit = { + // ,;{}()\n\t= and space are special characters in Parquet schema + analysisRequire( + !name.matches(".*[ ,;{}()\n\t=].*"), + s"""Attribute name "$name" contains invalid character(s) among " ,;{}()\\n\\t=". + |Please use alias to rename it. + """.stripMargin.split("\n").mkString(" ")) + } + + def analysisRequire(f: => Boolean, message: String): Unit = { + if (!f) { + throw new AnalysisException(message) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala index e65fa0030e179..0d96a1e8070b1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala @@ -86,8 +86,7 @@ private[parquet] class RowReadSupport extends ReadSupport[InternalRow] with Logg // TODO: Why it can be null? if (schema == null) { log.debug("falling back to Parquet read schema") - schema = ParquetTypesConverter.convertToAttributes( - parquetSchema, false, true) + schema = ParquetTypesConverter.convertToAttributes(parquetSchema, false, true) } log.debug(s"list of attributes that will be read: $schema") new RowRecordMaterializer(parquetSchema, schema) @@ -105,8 +104,7 @@ private[parquet] class RowReadSupport extends ReadSupport[InternalRow] with Logg // If the parquet file is thrift derived, there is a good chance that // it will have the thrift class in metadata. 
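Looking back at the maxPrecisionForBytes and minBytesForPrecision helpers defined just above, a small standalone check (a sketch in plain Python, not Spark code) shows why the converter maps decimals of precision up to 9 to INT32 and up to 18 to INT64, and how many bytes a FIXED_LEN_BYTE_ARRAY decimal needs.

import math

def max_precision_for_bytes(num_bytes):
    # Largest number of base-10 digits that fits in a signed num_bytes-byte integer.
    return int(math.floor(math.log10(2 ** (8 * num_bytes - 1) - 1)))

def min_bytes_for_precision(precision):
    # Smallest byte length whose signed maximum reaches 10**precision
    # (mirrors the Scala loop above).
    num_bytes = 1
    while 2.0 ** (8 * num_bytes - 1) < 10.0 ** precision:
        num_bytes += 1
    return num_bytes

assert max_precision_for_bytes(4) == 9     # DECIMAL(precision <= 9)  -> INT32
assert max_precision_for_bytes(8) == 18    # DECIMAL(precision <= 18) -> INT64
assert min_bytes_for_precision(10) == 5    # DECIMAL(10) needs a 5-byte FIXED_LEN_BYTE_ARRAY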
val isThriftDerived = keyValueMetaData.keySet().contains("thrift.class") - parquetSchema = ParquetTypesConverter - .convertFromAttributes(requestedAttributes, isThriftDerived) + parquetSchema = ParquetTypesConverter.convertFromAttributes(requestedAttributes) metadata.put( RowReadSupport.SPARK_ROW_REQUESTED_SCHEMA, ParquetTypesConverter.convertToString(requestedAttributes)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala index ba2a35b74ef82..4d5199a140344 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala @@ -29,214 +29,19 @@ import org.apache.parquet.format.converter.ParquetMetadataConverter import org.apache.parquet.hadoop.metadata.{FileMetaData, ParquetMetadata} import org.apache.parquet.hadoop.util.ContextUtil import org.apache.parquet.hadoop.{Footer, ParquetFileReader, ParquetFileWriter} -import org.apache.parquet.schema.PrimitiveType.{PrimitiveTypeName => ParquetPrimitiveTypeName} -import org.apache.parquet.schema.Type.Repetition -import org.apache.parquet.schema.{ConversionPatterns, DecimalMetadata, GroupType => ParquetGroupType, MessageType, OriginalType => ParquetOriginalType, PrimitiveType => ParquetPrimitiveType, Type => ParquetType, Types => ParquetTypes} +import org.apache.parquet.schema.MessageType import org.apache.spark.Logging -import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} +import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.types._ -/** A class representing Parquet info fields we care about, for passing back to Parquet */ -private[parquet] case class ParquetTypeInfo( - primitiveType: ParquetPrimitiveTypeName, - originalType: Option[ParquetOriginalType] = None, - decimalMetadata: Option[DecimalMetadata] = None, - length: Option[Int] = None) - private[parquet] object ParquetTypesConverter extends Logging { def isPrimitiveType(ctype: DataType): Boolean = ctype match { case _: NumericType | BooleanType | StringType | BinaryType => true case _: DataType => false } - def toPrimitiveDataType( - parquetType: ParquetPrimitiveType, - binaryAsString: Boolean, - int96AsTimestamp: Boolean): DataType = { - val originalType = parquetType.getOriginalType - val decimalInfo = parquetType.getDecimalMetadata - parquetType.getPrimitiveTypeName match { - case ParquetPrimitiveTypeName.BINARY - if (originalType == ParquetOriginalType.UTF8 || binaryAsString) => StringType - case ParquetPrimitiveTypeName.BINARY => BinaryType - case ParquetPrimitiveTypeName.BOOLEAN => BooleanType - case ParquetPrimitiveTypeName.DOUBLE => DoubleType - case ParquetPrimitiveTypeName.FLOAT => FloatType - case ParquetPrimitiveTypeName.INT32 - if originalType == ParquetOriginalType.DATE => DateType - case ParquetPrimitiveTypeName.INT32 => IntegerType - case ParquetPrimitiveTypeName.INT64 => LongType - case ParquetPrimitiveTypeName.INT96 if int96AsTimestamp => TimestampType - case ParquetPrimitiveTypeName.INT96 => - // TODO: add BigInteger type? TODO(andre) use DecimalType instead???? 
- throw new AnalysisException("Potential loss of precision: cannot convert INT96") - case ParquetPrimitiveTypeName.FIXED_LEN_BYTE_ARRAY - if (originalType == ParquetOriginalType.DECIMAL && decimalInfo.getPrecision <= 18) => - // TODO: for now, our reader only supports decimals that fit in a Long - DecimalType(decimalInfo.getPrecision, decimalInfo.getScale) - case _ => throw new AnalysisException(s"Unsupported parquet datatype $parquetType") - } - } - - /** - * Converts a given Parquet `Type` into the corresponding - * [[org.apache.spark.sql.types.DataType]]. - * - * We apply the following conversion rules: - *
<ul> - *   <li> Primitive types are converter to the corresponding primitive type.</li> - *   <li> Group types that have a single field that is itself a group, which has repetition - *        level `REPEATED`, are treated as follows: - *        <ul> - *          <li> If the nested group has name `values`, the surrounding group is converted - *               into an [[ArrayType]] with the corresponding field type (primitive or - *               complex) as element type.</li> - *          <li> If the nested group has name `map` and two fields (named `key` and `value`), - *               the surrounding group is converted into a [[MapType]] - *               with the corresponding key and value (value possibly complex) types. - *               Note that we currently assume map values are not nullable.</li> - *          <li> Other group types are converted into a [[StructType]] with the corresponding - *               field types.</li> - *        </ul></li> - * </ul>
- * Note that fields are determined to be `nullable` if and only if their Parquet repetition - * level is not `REQUIRED`. - * - * @param parquetType The type to convert. - * @return The corresponding Catalyst type. - */ - def toDataType(parquetType: ParquetType, - isBinaryAsString: Boolean, - isInt96AsTimestamp: Boolean): DataType = { - def correspondsToMap(groupType: ParquetGroupType): Boolean = { - if (groupType.getFieldCount != 1 || groupType.getFields.apply(0).isPrimitive) { - false - } else { - // This mostly follows the convention in ``parquet.schema.ConversionPatterns`` - val keyValueGroup = groupType.getFields.apply(0).asGroupType() - keyValueGroup.getRepetition == Repetition.REPEATED && - keyValueGroup.getName == CatalystConverter.MAP_SCHEMA_NAME && - keyValueGroup.getFieldCount == 2 && - keyValueGroup.getFields.apply(0).getName == CatalystConverter.MAP_KEY_SCHEMA_NAME && - keyValueGroup.getFields.apply(1).getName == CatalystConverter.MAP_VALUE_SCHEMA_NAME - } - } - - def correspondsToArray(groupType: ParquetGroupType): Boolean = { - groupType.getFieldCount == 1 && - groupType.getFieldName(0) == CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME && - groupType.getFields.apply(0).getRepetition == Repetition.REPEATED - } - - if (parquetType.isPrimitive) { - toPrimitiveDataType(parquetType.asPrimitiveType, isBinaryAsString, isInt96AsTimestamp) - } else { - val groupType = parquetType.asGroupType() - parquetType.getOriginalType match { - // if the schema was constructed programmatically there may be hints how to convert - // it inside the metadata via the OriginalType field - case ParquetOriginalType.LIST => { // TODO: check enums! - assert(groupType.getFieldCount == 1) - val field = groupType.getFields.apply(0) - if (field.getName == CatalystConverter.ARRAY_CONTAINS_NULL_BAG_SCHEMA_NAME) { - val bag = field.asGroupType() - assert(bag.getFieldCount == 1) - ArrayType( - toDataType(bag.getFields.apply(0), isBinaryAsString, isInt96AsTimestamp), - containsNull = true) - } else { - ArrayType( - toDataType(field, isBinaryAsString, isInt96AsTimestamp), containsNull = false) - } - } - case ParquetOriginalType.MAP => { - assert( - !groupType.getFields.apply(0).isPrimitive, - "Parquet Map type malformatted: expected nested group for map!") - val keyValueGroup = groupType.getFields.apply(0).asGroupType() - assert( - keyValueGroup.getFieldCount == 2, - "Parquet Map type malformatted: nested group should have 2 (key, value) fields!") - assert(keyValueGroup.getFields.apply(0).getRepetition == Repetition.REQUIRED) - - val keyType = - toDataType(keyValueGroup.getFields.apply(0), isBinaryAsString, isInt96AsTimestamp) - val valueType = - toDataType(keyValueGroup.getFields.apply(1), isBinaryAsString, isInt96AsTimestamp) - MapType(keyType, valueType, - keyValueGroup.getFields.apply(1).getRepetition != Repetition.REQUIRED) - } - case _ => { - // Note: the order of these checks is important! 
- if (correspondsToMap(groupType)) { // MapType - val keyValueGroup = groupType.getFields.apply(0).asGroupType() - assert(keyValueGroup.getFields.apply(0).getRepetition == Repetition.REQUIRED) - - val keyType = - toDataType(keyValueGroup.getFields.apply(0), isBinaryAsString, isInt96AsTimestamp) - val valueType = - toDataType(keyValueGroup.getFields.apply(1), isBinaryAsString, isInt96AsTimestamp) - MapType(keyType, valueType, - keyValueGroup.getFields.apply(1).getRepetition != Repetition.REQUIRED) - } else if (correspondsToArray(groupType)) { // ArrayType - val field = groupType.getFields.apply(0) - if (field.getName == CatalystConverter.ARRAY_CONTAINS_NULL_BAG_SCHEMA_NAME) { - val bag = field.asGroupType() - assert(bag.getFieldCount == 1) - ArrayType( - toDataType(bag.getFields.apply(0), isBinaryAsString, isInt96AsTimestamp), - containsNull = true) - } else { - ArrayType( - toDataType(field, isBinaryAsString, isInt96AsTimestamp), containsNull = false) - } - } else { // everything else: StructType - val fields = groupType - .getFields - .map(ptype => new StructField( - ptype.getName, - toDataType(ptype, isBinaryAsString, isInt96AsTimestamp), - ptype.getRepetition != Repetition.REQUIRED)) - StructType(fields) - } - } - } - } - } - - /** - * For a given Catalyst [[org.apache.spark.sql.types.DataType]] return - * the name of the corresponding Parquet primitive type or None if the given type - * is not primitive. - * - * @param ctype The type to convert - * @return The name of the corresponding Parquet type properties - */ - def fromPrimitiveDataType(ctype: DataType): Option[ParquetTypeInfo] = ctype match { - case StringType => Some(ParquetTypeInfo( - ParquetPrimitiveTypeName.BINARY, Some(ParquetOriginalType.UTF8))) - case BinaryType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.BINARY)) - case BooleanType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.BOOLEAN)) - case DoubleType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.DOUBLE)) - case FloatType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.FLOAT)) - case IntegerType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.INT32)) - // There is no type for Byte or Short so we promote them to INT32. - case ShortType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.INT32)) - case ByteType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.INT32)) - case DateType => Some(ParquetTypeInfo( - ParquetPrimitiveTypeName.INT32, Some(ParquetOriginalType.DATE))) - case LongType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.INT64)) - case TimestampType => Some(ParquetTypeInfo(ParquetPrimitiveTypeName.INT96)) - case DecimalType.Fixed(precision, scale) if precision <= 18 => - // TODO: for now, our writer only supports decimals that fit in a Long - Some(ParquetTypeInfo(ParquetPrimitiveTypeName.FIXED_LEN_BYTE_ARRAY, - Some(ParquetOriginalType.DECIMAL), - Some(new DecimalMetadata(precision, scale)), - Some(BYTES_FOR_PRECISION(precision)))) - case _ => None - } - /** * Compute the FIXED_LEN_BYTE_ARRAY length needed to represent a given DECIMAL precision. */ @@ -248,177 +53,29 @@ private[parquet] object ParquetTypesConverter extends Logging { length } - /** - * Converts a given Catalyst [[org.apache.spark.sql.types.DataType]] into - * the corresponding Parquet `Type`. - * - * The conversion follows the rules below: - *
<ul> - *   <li> Primitive types are converted into Parquet's primitive types.</li> - *   <li> [[org.apache.spark.sql.types.StructType]]s are converted - *        into Parquet's `GroupType` with the corresponding field types.</li> - *   <li> [[org.apache.spark.sql.types.ArrayType]]s are converted - *        into a 2-level nested group, where the outer group has the inner - *        group as sole field. The inner group has name `values` and - *        repetition level `REPEATED` and has the element type of - *        the array as schema. We use Parquet's `ConversionPatterns` for this - *        purpose.</li> - *   <li> [[org.apache.spark.sql.types.MapType]]s are converted - *        into a nested (2-level) Parquet `GroupType` with two fields: a key - *        type and a value type. The nested group has repetition level - *        `REPEATED` and name `map`. We use Parquet's `ConversionPatterns` - *        for this purpose</li> - * </ul> - * Parquet's repetition level is generally set according to the following rule: - * <ul> - *   <li> If the call to `fromDataType` is recursive inside an enclosing `ArrayType` or - *        `MapType`, then the repetition level is set to `REPEATED`.</li> - *   <li> Otherwise, if the attribute whose type is converted is `nullable`, the Parquet - *        type gets repetition level `OPTIONAL` and otherwise `REQUIRED`.</li> - * </ul>
- * - *@param ctype The type to convert - * @param name The name of the [[org.apache.spark.sql.catalyst.expressions.Attribute]] - * whose type is converted - * @param nullable When true indicates that the attribute is nullable - * @param inArray When true indicates that this is a nested attribute inside an array. - * @return The corresponding Parquet type. - */ - def fromDataType( - ctype: DataType, - name: String, - nullable: Boolean = true, - inArray: Boolean = false, - toThriftSchemaNames: Boolean = false): ParquetType = { - val repetition = - if (inArray) { - Repetition.REPEATED - } else { - if (nullable) Repetition.OPTIONAL else Repetition.REQUIRED - } - val arraySchemaName = if (toThriftSchemaNames) { - name + CatalystConverter.THRIFT_ARRAY_ELEMENTS_SCHEMA_NAME_SUFFIX - } else { - CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME - } - val typeInfo = fromPrimitiveDataType(ctype) - typeInfo.map { - case ParquetTypeInfo(primitiveType, originalType, decimalMetadata, length) => - val builder = ParquetTypes.primitive(primitiveType, repetition).as(originalType.orNull) - for (len <- length) { - builder.length(len) - } - for (metadata <- decimalMetadata) { - builder.precision(metadata.getPrecision).scale(metadata.getScale) - } - builder.named(name) - }.getOrElse { - ctype match { - case udt: UserDefinedType[_] => { - fromDataType(udt.sqlType, name, nullable, inArray, toThriftSchemaNames) - } - case ArrayType(elementType, false) => { - val parquetElementType = fromDataType( - elementType, - arraySchemaName, - nullable = false, - inArray = true, - toThriftSchemaNames) - ConversionPatterns.listType(repetition, name, parquetElementType) - } - case ArrayType(elementType, true) => { - val parquetElementType = fromDataType( - elementType, - arraySchemaName, - nullable = true, - inArray = false, - toThriftSchemaNames) - ConversionPatterns.listType( - repetition, - name, - new ParquetGroupType( - Repetition.REPEATED, - CatalystConverter.ARRAY_CONTAINS_NULL_BAG_SCHEMA_NAME, - parquetElementType)) - } - case StructType(structFields) => { - val fields = structFields.map { - field => fromDataType(field.dataType, field.name, field.nullable, - inArray = false, toThriftSchemaNames) - } - new ParquetGroupType(repetition, name, fields.toSeq) - } - case MapType(keyType, valueType, valueContainsNull) => { - val parquetKeyType = - fromDataType( - keyType, - CatalystConverter.MAP_KEY_SCHEMA_NAME, - nullable = false, - inArray = false, - toThriftSchemaNames) - val parquetValueType = - fromDataType( - valueType, - CatalystConverter.MAP_VALUE_SCHEMA_NAME, - nullable = valueContainsNull, - inArray = false, - toThriftSchemaNames) - ConversionPatterns.mapType( - repetition, - name, - parquetKeyType, - parquetValueType) - } - case _ => throw new AnalysisException(s"Unsupported datatype $ctype") - } - } - } - - def convertToAttributes(parquetSchema: ParquetType, - isBinaryAsString: Boolean, - isInt96AsTimestamp: Boolean): Seq[Attribute] = { - parquetSchema - .asGroupType() - .getFields - .map( - field => - new AttributeReference( - field.getName, - toDataType(field, isBinaryAsString, isInt96AsTimestamp), - field.getRepetition != Repetition.REQUIRED)()) + def convertToAttributes( + parquetSchema: MessageType, + isBinaryAsString: Boolean, + isInt96AsTimestamp: Boolean): Seq[Attribute] = { + val converter = new CatalystSchemaConverter( + isBinaryAsString, isInt96AsTimestamp, followParquetFormatSpec = false) + converter.convert(parquetSchema).toAttributes } - def convertFromAttributes(attributes: Seq[Attribute], - 
toThriftSchemaNames: Boolean = false): MessageType = { - checkSpecialCharacters(attributes) - val fields = attributes.map( - attribute => - fromDataType(attribute.dataType, attribute.name, attribute.nullable, - toThriftSchemaNames = toThriftSchemaNames)) - new MessageType("root", fields) + def convertFromAttributes(attributes: Seq[Attribute]): MessageType = { + val converter = new CatalystSchemaConverter() + converter.convert(StructType.fromAttributes(attributes)) } def convertFromString(string: String): Seq[Attribute] = { Try(DataType.fromJson(string)).getOrElse(DataType.fromCaseClassString(string)) match { case s: StructType => s.toAttributes - case other => throw new AnalysisException(s"Can convert $string to row") - } - } - - private def checkSpecialCharacters(schema: Seq[Attribute]) = { - // ,;{}()\n\t= and space character are special characters in Parquet schema - schema.map(_.name).foreach { name => - if (name.matches(".*[ ,;{}()\n\t=].*")) { - throw new AnalysisException( - s"""Attribute name "$name" contains invalid character(s) among " ,;{}()\\n\\t=". - |Please use alias to rename it. - """.stripMargin.split("\n").mkString(" ")) - } + case other => sys.error(s"Can convert $string to row") } } def convertToString(schema: Seq[Attribute]): String = { - checkSpecialCharacters(schema) + schema.map(_.name).foreach(CatalystSchemaConverter.checkFieldName) StructType.fromAttributes(schema).json } @@ -450,8 +107,7 @@ private[parquet] object ParquetTypesConverter extends Logging { ParquetTypesConverter.convertToString(attributes)) // TODO: add extra data, e.g., table name, date, etc.? - val parquetSchema: MessageType = - ParquetTypesConverter.convertFromAttributes(attributes) + val parquetSchema: MessageType = ParquetTypesConverter.convertFromAttributes(attributes) val metaData: FileMetaData = new FileMetaData( parquetSchema, extraMetadata, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala index 47a7be1c6a664..7b16eba00d6fb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala @@ -99,7 +99,6 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { } test("fixed-length decimals") { - def makeDecimalRDD(decimal: DecimalType): DataFrame = sqlContext.sparkContext .parallelize(0 to 1000) @@ -158,6 +157,11 @@ class ParquetIOSuiteBase extends QueryTest with ParquetTest { checkParquetFile(data) } + test("array and double") { + val data = (1 to 4).map(i => (i.toDouble, Seq(i.toDouble, (i + 1).toDouble))) + checkParquetFile(data) + } + test("struct") { val data = (1 to 4).map(i => Tuple1((i, s"val_$i"))) withParquetDataFrame(data) { df => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala index 171a656f0e01e..d0bfcde7e032b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala @@ -24,26 +24,109 @@ import org.apache.parquet.schema.MessageTypeParser import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.types._ -class ParquetSchemaSuite extends SparkFunSuite with ParquetTest { - lazy val sqlContext = 
org.apache.spark.sql.test.TestSQLContext +abstract class ParquetSchemaTest extends SparkFunSuite with ParquetTest { + val sqlContext = TestSQLContext /** * Checks whether the reflected Parquet message type for product type `T` conforms `messageType`. */ - private def testSchema[T <: Product: ClassTag: TypeTag]( - testName: String, messageType: String, isThriftDerived: Boolean = false): Unit = { - test(testName) { - val actual = ParquetTypesConverter.convertFromAttributes( - ScalaReflection.attributesFor[T], isThriftDerived) - val expected = MessageTypeParser.parseMessageType(messageType) + protected def testSchemaInference[T <: Product: ClassTag: TypeTag]( + testName: String, + messageType: String, + binaryAsString: Boolean = true, + int96AsTimestamp: Boolean = true, + followParquetFormatSpec: Boolean = false, + isThriftDerived: Boolean = false): Unit = { + testSchema( + testName, + StructType.fromAttributes(ScalaReflection.attributesFor[T]), + messageType, + binaryAsString, + int96AsTimestamp, + followParquetFormatSpec, + isThriftDerived) + } + + protected def testParquetToCatalyst( + testName: String, + sqlSchema: StructType, + parquetSchema: String, + binaryAsString: Boolean = true, + int96AsTimestamp: Boolean = true, + followParquetFormatSpec: Boolean = false, + isThriftDerived: Boolean = false): Unit = { + val converter = new CatalystSchemaConverter( + assumeBinaryIsString = binaryAsString, + assumeInt96IsTimestamp = int96AsTimestamp, + followParquetFormatSpec = followParquetFormatSpec) + + test(s"sql <= parquet: $testName") { + val actual = converter.convert(MessageTypeParser.parseMessageType(parquetSchema)) + val expected = sqlSchema + assert( + actual === expected, + s"""Schema mismatch. + |Expected schema: ${expected.json} + |Actual schema: ${actual.json} + """.stripMargin) + } + } + + protected def testCatalystToParquet( + testName: String, + sqlSchema: StructType, + parquetSchema: String, + binaryAsString: Boolean = true, + int96AsTimestamp: Boolean = true, + followParquetFormatSpec: Boolean = false, + isThriftDerived: Boolean = false): Unit = { + val converter = new CatalystSchemaConverter( + assumeBinaryIsString = binaryAsString, + assumeInt96IsTimestamp = int96AsTimestamp, + followParquetFormatSpec = followParquetFormatSpec) + + test(s"sql => parquet: $testName") { + val actual = converter.convert(sqlSchema) + val expected = MessageTypeParser.parseMessageType(parquetSchema) actual.checkContains(expected) expected.checkContains(actual) } } - testSchema[(Boolean, Int, Long, Float, Double, Array[Byte])]( + protected def testSchema( + testName: String, + sqlSchema: StructType, + parquetSchema: String, + binaryAsString: Boolean = true, + int96AsTimestamp: Boolean = true, + followParquetFormatSpec: Boolean = false, + isThriftDerived: Boolean = false): Unit = { + + testCatalystToParquet( + testName, + sqlSchema, + parquetSchema, + binaryAsString, + int96AsTimestamp, + followParquetFormatSpec, + isThriftDerived) + + testParquetToCatalyst( + testName, + sqlSchema, + parquetSchema, + binaryAsString, + int96AsTimestamp, + followParquetFormatSpec, + isThriftDerived) + } +} + +class ParquetSchemaInferenceSuite extends ParquetSchemaTest { + testSchemaInference[(Boolean, Int, Long, Float, Double, Array[Byte])]( "basic types", """ |message root { @@ -54,9 +137,10 @@ class ParquetSchemaSuite extends SparkFunSuite with ParquetTest { | required double _5; | optional binary _6; |} - """.stripMargin) + """.stripMargin, + binaryAsString = false) - testSchema[(Byte, Short, Int, Long, 
java.sql.Date)]( + testSchemaInference[(Byte, Short, Int, Long, java.sql.Date)]( "logical integral types", """ |message root { @@ -68,27 +152,79 @@ class ParquetSchemaSuite extends SparkFunSuite with ParquetTest { |} """.stripMargin) - // Currently String is the only supported logical binary type. - testSchema[Tuple1[String]]( - "binary logical types", + testSchemaInference[Tuple1[String]]( + "string", """ |message root { | optional binary _1 (UTF8); |} + """.stripMargin, + binaryAsString = true) + + testSchemaInference[Tuple1[Seq[Int]]]( + "non-nullable array - non-standard", + """ + |message root { + | optional group _1 (LIST) { + | repeated int32 element; + | } + |} """.stripMargin) - testSchema[Tuple1[Seq[Int]]]( - "array", + testSchemaInference[Tuple1[Seq[Int]]]( + "non-nullable array - standard", + """ + |message root { + | optional group _1 (LIST) { + | repeated group list { + | required int32 element; + | } + | } + |} + """.stripMargin, + followParquetFormatSpec = true) + + testSchemaInference[Tuple1[Seq[Integer]]]( + "nullable array - non-standard", """ |message root { | optional group _1 (LIST) { - | repeated int32 array; + | repeated group bag { + | optional int32 element; + | } | } |} """.stripMargin) - testSchema[Tuple1[Map[Int, String]]]( - "map", + testSchemaInference[Tuple1[Seq[Integer]]]( + "nullable array - standard", + """ + |message root { + | optional group _1 (LIST) { + | repeated group list { + | optional int32 element; + | } + | } + |} + """.stripMargin, + followParquetFormatSpec = true) + + testSchemaInference[Tuple1[Map[Int, String]]]( + "map - standard", + """ + |message root { + | optional group _1 (MAP) { + | repeated group key_value { + | required int32 key; + | optional binary value (UTF8); + | } + | } + |} + """.stripMargin, + followParquetFormatSpec = true) + + testSchemaInference[Tuple1[Map[Int, String]]]( + "map - non-standard", """ |message root { | optional group _1 (MAP) { @@ -100,7 +236,7 @@ class ParquetSchemaSuite extends SparkFunSuite with ParquetTest { |} """.stripMargin) - testSchema[Tuple1[Pair[Int, String]]]( + testSchemaInference[Tuple1[Pair[Int, String]]]( "struct", """ |message root { @@ -109,20 +245,21 @@ class ParquetSchemaSuite extends SparkFunSuite with ParquetTest { | optional binary _2 (UTF8); | } |} - """.stripMargin) + """.stripMargin, + followParquetFormatSpec = true) - testSchema[Tuple1[Map[Int, (String, Seq[(Int, Double)])]]]( - "deeply nested type", + testSchemaInference[Tuple1[Map[Int, (String, Seq[(Int, Double)])]]]( + "deeply nested type - non-standard", """ |message root { - | optional group _1 (MAP) { - | repeated group map (MAP_KEY_VALUE) { + | optional group _1 (MAP_KEY_VALUE) { + | repeated group map { | required int32 key; | optional group value { | optional binary _1 (UTF8); | optional group _2 (LIST) { | repeated group bag { - | optional group array { + | optional group element { | required int32 _1; | required double _2; | } @@ -134,43 +271,76 @@ class ParquetSchemaSuite extends SparkFunSuite with ParquetTest { |} """.stripMargin) - testSchema[(Option[Int], Map[Int, Option[Double]])]( - "optional types", + testSchemaInference[Tuple1[Map[Int, (String, Seq[(Int, Double)])]]]( + "deeply nested type - standard", """ |message root { - | optional int32 _1; - | optional group _2 (MAP) { - | repeated group map (MAP_KEY_VALUE) { + | optional group _1 (MAP) { + | repeated group key_value { | required int32 key; - | optional double value; + | optional group value { + | optional binary _1 (UTF8); + | optional group _2 (LIST) { + 
| repeated group list { + | optional group element { + | required int32 _1; + | required double _2; + | } + | } + | } + | } | } | } |} - """.stripMargin) + """.stripMargin, + followParquetFormatSpec = true) - // Test for SPARK-4520 -- ensure that thrift generated parquet schema is generated - // as expected from attributes - testSchema[(Array[Byte], Array[Byte], Array[Byte], Seq[Int], Map[Array[Byte], Seq[Int]])]( - "thrift generated parquet schema", + testSchemaInference[(Option[Int], Map[Int, Option[Double]])]( + "optional types", """ |message root { - | optional binary _1 (UTF8); - | optional binary _2 (UTF8); - | optional binary _3 (UTF8); - | optional group _4 (LIST) { - | repeated int32 _4_tuple; - | } - | optional group _5 (MAP) { - | repeated group map (MAP_KEY_VALUE) { - | required binary key (UTF8); - | optional group value (LIST) { - | repeated int32 value_tuple; - | } + | optional int32 _1; + | optional group _2 (MAP) { + | repeated group key_value { + | required int32 key; + | optional double value; | } | } |} - """.stripMargin, isThriftDerived = true) + """.stripMargin, + followParquetFormatSpec = true) + // Parquet files generated by parquet-thrift are already handled by the schema converter, but + // let's leave this test here until both read path and write path are all updated. + ignore("thrift generated parquet schema") { + // Test for SPARK-4520 -- ensure that thrift generated parquet schema is generated + // as expected from attributes + testSchemaInference[( + Array[Byte], Array[Byte], Array[Byte], Seq[Int], Map[Array[Byte], Seq[Int]])]( + "thrift generated parquet schema", + """ + |message root { + | optional binary _1 (UTF8); + | optional binary _2 (UTF8); + | optional binary _3 (UTF8); + | optional group _4 (LIST) { + | repeated int32 _4_tuple; + | } + | optional group _5 (MAP) { + | repeated group map (MAP_KEY_VALUE) { + | required binary key (UTF8); + | optional group value (LIST) { + | repeated int32 value_tuple; + | } + | } + | } + |} + """.stripMargin, + isThriftDerived = true) + } +} + +class ParquetSchemaSuite extends ParquetSchemaTest { test("DataType string parser compatibility") { // This is the generated string from previous versions of the Spark SQL, using the following: // val schema = StructType(List( @@ -180,10 +350,7 @@ class ParquetSchemaSuite extends SparkFunSuite with ParquetTest { "StructType(List(StructField(c1,IntegerType,false), StructField(c2,BinaryType,true)))" // scalastyle:off - val jsonString = - """ - |{"type":"struct","fields":[{"name":"c1","type":"integer","nullable":false,"metadata":{}},{"name":"c2","type":"binary","nullable":true,"metadata":{}}]} - """.stripMargin + val jsonString = """{"type":"struct","fields":[{"name":"c1","type":"integer","nullable":false,"metadata":{}},{"name":"c2","type":"binary","nullable":true,"metadata":{}}]}""" // scalastyle:on val fromCaseClassString = ParquetTypesConverter.convertFromString(caseClassString) @@ -277,4 +444,465 @@ class ParquetSchemaSuite extends SparkFunSuite with ParquetTest { StructField("secondField", StringType, nullable = true)))) }.getMessage.contains("detected conflicting schemas")) } + + // ======================================================= + // Tests for converting Parquet LIST to Catalyst ArrayType + // ======================================================= + + testParquetToCatalyst( + "Backwards-compatibility: LIST with nullable element type - 1 - standard", + StructType(Seq( + StructField( + "f1", + ArrayType(IntegerType, containsNull = true), + nullable = true))), + 
"""message root { + | optional group f1 (LIST) { + | repeated group list { + | optional int32 element; + | } + | } + |} + """.stripMargin) + + testParquetToCatalyst( + "Backwards-compatibility: LIST with nullable element type - 2", + StructType(Seq( + StructField( + "f1", + ArrayType(IntegerType, containsNull = true), + nullable = true))), + """message root { + | optional group f1 (LIST) { + | repeated group element { + | optional int32 num; + | } + | } + |} + """.stripMargin) + + testParquetToCatalyst( + "Backwards-compatibility: LIST with non-nullable element type - 1 - standard", + StructType(Seq( + StructField("f1", ArrayType(IntegerType, containsNull = false), nullable = true))), + """message root { + | optional group f1 (LIST) { + | repeated group list { + | required int32 element; + | } + | } + |} + """.stripMargin) + + testParquetToCatalyst( + "Backwards-compatibility: LIST with non-nullable element type - 2", + StructType(Seq( + StructField("f1", ArrayType(IntegerType, containsNull = false), nullable = true))), + """message root { + | optional group f1 (LIST) { + | repeated group element { + | required int32 num; + | } + | } + |} + """.stripMargin) + + testParquetToCatalyst( + "Backwards-compatibility: LIST with non-nullable element type - 3", + StructType(Seq( + StructField("f1", ArrayType(IntegerType, containsNull = false), nullable = true))), + """message root { + | optional group f1 (LIST) { + | repeated int32 element; + | } + |} + """.stripMargin) + + testParquetToCatalyst( + "Backwards-compatibility: LIST with non-nullable element type - 4", + StructType(Seq( + StructField( + "f1", + ArrayType( + StructType(Seq( + StructField("str", StringType, nullable = false), + StructField("num", IntegerType, nullable = false))), + containsNull = false), + nullable = true))), + """message root { + | optional group f1 (LIST) { + | repeated group element { + | required binary str (UTF8); + | required int32 num; + | } + | } + |} + """.stripMargin) + + testParquetToCatalyst( + "Backwards-compatibility: LIST with non-nullable element type - 5 - parquet-avro style", + StructType(Seq( + StructField( + "f1", + ArrayType( + StructType(Seq( + StructField("str", StringType, nullable = false))), + containsNull = false), + nullable = true))), + """message root { + | optional group f1 (LIST) { + | repeated group array { + | required binary str (UTF8); + | } + | } + |} + """.stripMargin) + + testParquetToCatalyst( + "Backwards-compatibility: LIST with non-nullable element type - 6 - parquet-thrift style", + StructType(Seq( + StructField( + "f1", + ArrayType( + StructType(Seq( + StructField("str", StringType, nullable = false))), + containsNull = false), + nullable = true))), + """message root { + | optional group f1 (LIST) { + | repeated group f1_tuple { + | required binary str (UTF8); + | } + | } + |} + """.stripMargin) + + // ======================================================= + // Tests for converting Catalyst ArrayType to Parquet LIST + // ======================================================= + + testCatalystToParquet( + "Backwards-compatibility: LIST with nullable element type - 1 - standard", + StructType(Seq( + StructField( + "f1", + ArrayType(IntegerType, containsNull = true), + nullable = true))), + """message root { + | optional group f1 (LIST) { + | repeated group list { + | optional int32 element; + | } + | } + |} + """.stripMargin, + followParquetFormatSpec = true) + + testCatalystToParquet( + "Backwards-compatibility: LIST with nullable element type - 2 - prior to 1.4.x", + 
StructType(Seq( + StructField( + "f1", + ArrayType(IntegerType, containsNull = true), + nullable = true))), + """message root { + | optional group f1 (LIST) { + | repeated group bag { + | optional int32 element; + | } + | } + |} + """.stripMargin) + + testCatalystToParquet( + "Backwards-compatibility: LIST with non-nullable element type - 1 - standard", + StructType(Seq( + StructField( + "f1", + ArrayType(IntegerType, containsNull = false), + nullable = true))), + """message root { + | optional group f1 (LIST) { + | repeated group list { + | required int32 element; + | } + | } + |} + """.stripMargin, + followParquetFormatSpec = true) + + testCatalystToParquet( + "Backwards-compatibility: LIST with non-nullable element type - 2 - prior to 1.4.x", + StructType(Seq( + StructField( + "f1", + ArrayType(IntegerType, containsNull = false), + nullable = true))), + """message root { + | optional group f1 (LIST) { + | repeated int32 element; + | } + |} + """.stripMargin) + + // ==================================================== + // Tests for converting Parquet Map to Catalyst MapType + // ==================================================== + + testParquetToCatalyst( + "Backwards-compatibility: MAP with non-nullable value type - 1 - standard", + StructType(Seq( + StructField( + "f1", + MapType(IntegerType, StringType, valueContainsNull = false), + nullable = true))), + """message root { + | optional group f1 (MAP) { + | repeated group key_value { + | required int32 key; + | required binary value (UTF8); + | } + | } + |} + """.stripMargin) + + testParquetToCatalyst( + "Backwards-compatibility: MAP with non-nullable value type - 2", + StructType(Seq( + StructField( + "f1", + MapType(IntegerType, StringType, valueContainsNull = false), + nullable = true))), + """message root { + | optional group f1 (MAP_KEY_VALUE) { + | repeated group map { + | required int32 num; + | required binary str (UTF8); + | } + | } + |} + """.stripMargin) + + testParquetToCatalyst( + "Backwards-compatibility: MAP with non-nullable value type - 3 - prior to 1.4.x", + StructType(Seq( + StructField( + "f1", + MapType(IntegerType, StringType, valueContainsNull = false), + nullable = true))), + """message root { + | optional group f1 (MAP) { + | repeated group map (MAP_KEY_VALUE) { + | required int32 key; + | required binary value (UTF8); + | } + | } + |} + """.stripMargin) + + testParquetToCatalyst( + "Backwards-compatibility: MAP with nullable value type - 1 - standard", + StructType(Seq( + StructField( + "f1", + MapType(IntegerType, StringType, valueContainsNull = true), + nullable = true))), + """message root { + | optional group f1 (MAP) { + | repeated group key_value { + | required int32 key; + | optional binary value (UTF8); + | } + | } + |} + """.stripMargin) + + testParquetToCatalyst( + "Backwards-compatibility: MAP with nullable value type - 2", + StructType(Seq( + StructField( + "f1", + MapType(IntegerType, StringType, valueContainsNull = true), + nullable = true))), + """message root { + | optional group f1 (MAP_KEY_VALUE) { + | repeated group map { + | required int32 num; + | optional binary str (UTF8); + | } + | } + |} + """.stripMargin) + + testParquetToCatalyst( + "Backwards-compatibility: MAP with nullable value type - 3 - parquet-avro style", + StructType(Seq( + StructField( + "f1", + MapType(IntegerType, StringType, valueContainsNull = true), + nullable = true))), + """message root { + | optional group f1 (MAP) { + | repeated group map (MAP_KEY_VALUE) { + | required int32 key; + | optional binary value (UTF8); 
+ | } + | } + |} + """.stripMargin) + + // ==================================================== + // Tests for converting Catalyst MapType to Parquet Map + // ==================================================== + + testCatalystToParquet( + "Backwards-compatibility: MAP with non-nullable value type - 1 - standard", + StructType(Seq( + StructField( + "f1", + MapType(IntegerType, StringType, valueContainsNull = false), + nullable = true))), + """message root { + | optional group f1 (MAP) { + | repeated group key_value { + | required int32 key; + | required binary value (UTF8); + | } + | } + |} + """.stripMargin, + followParquetFormatSpec = true) + + testCatalystToParquet( + "Backwards-compatibility: MAP with non-nullable value type - 2 - prior to 1.4.x", + StructType(Seq( + StructField( + "f1", + MapType(IntegerType, StringType, valueContainsNull = false), + nullable = true))), + """message root { + | optional group f1 (MAP) { + | repeated group map (MAP_KEY_VALUE) { + | required int32 key; + | required binary value (UTF8); + | } + | } + |} + """.stripMargin) + + testCatalystToParquet( + "Backwards-compatibility: MAP with nullable value type - 1 - standard", + StructType(Seq( + StructField( + "f1", + MapType(IntegerType, StringType, valueContainsNull = true), + nullable = true))), + """message root { + | optional group f1 (MAP) { + | repeated group key_value { + | required int32 key; + | optional binary value (UTF8); + | } + | } + |} + """.stripMargin, + followParquetFormatSpec = true) + + testCatalystToParquet( + "Backwards-compatibility: MAP with nullable value type - 3 - prior to 1.4.x", + StructType(Seq( + StructField( + "f1", + MapType(IntegerType, StringType, valueContainsNull = true), + nullable = true))), + """message root { + | optional group f1 (MAP) { + | repeated group map (MAP_KEY_VALUE) { + | required int32 key; + | optional binary value (UTF8); + | } + | } + |} + """.stripMargin) + + // ================================= + // Tests for conversion for decimals + // ================================= + + testSchema( + "DECIMAL(1, 0) - standard", + StructType(Seq(StructField("f1", DecimalType(1, 0)))), + """message root { + | optional int32 f1 (DECIMAL(1, 0)); + |} + """.stripMargin, + followParquetFormatSpec = true) + + testSchema( + "DECIMAL(8, 3) - standard", + StructType(Seq(StructField("f1", DecimalType(8, 3)))), + """message root { + | optional int32 f1 (DECIMAL(8, 3)); + |} + """.stripMargin, + followParquetFormatSpec = true) + + testSchema( + "DECIMAL(9, 3) - standard", + StructType(Seq(StructField("f1", DecimalType(9, 3)))), + """message root { + | optional int32 f1 (DECIMAL(9, 3)); + |} + """.stripMargin, + followParquetFormatSpec = true) + + testSchema( + "DECIMAL(18, 3) - standard", + StructType(Seq(StructField("f1", DecimalType(18, 3)))), + """message root { + | optional int64 f1 (DECIMAL(18, 3)); + |} + """.stripMargin, + followParquetFormatSpec = true) + + testSchema( + "DECIMAL(19, 3) - standard", + StructType(Seq(StructField("f1", DecimalType(19, 3)))), + """message root { + | optional fixed_len_byte_array(9) f1 (DECIMAL(19, 3)); + |} + """.stripMargin, + followParquetFormatSpec = true) + + testSchema( + "DECIMAL(1, 0) - prior to 1.4.x", + StructType(Seq(StructField("f1", DecimalType(1, 0)))), + """message root { + | optional fixed_len_byte_array(1) f1 (DECIMAL(1, 0)); + |} + """.stripMargin) + + testSchema( + "DECIMAL(8, 3) - prior to 1.4.x", + StructType(Seq(StructField("f1", DecimalType(8, 3)))), + """message root { + | optional fixed_len_byte_array(4) f1 
(DECIMAL(8, 3)); + |} + """.stripMargin) + + testSchema( + "DECIMAL(9, 3) - prior to 1.4.x", + StructType(Seq(StructField("f1", DecimalType(9, 3)))), + """message root { + | optional fixed_len_byte_array(5) f1 (DECIMAL(9, 3)); + |} + """.stripMargin) + + testSchema( + "DECIMAL(18, 3) - prior to 1.4.x", + StructType(Seq(StructField("f1", DecimalType(18, 3)))), + """message root { + | optional fixed_len_byte_array(8) f1 (DECIMAL(18, 3)); + |} + """.stripMargin) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index a2e666586c186..f0aad8dbbe64d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -638,7 +638,7 @@ class SQLQuerySuite extends QueryTest { test("SPARK-5203 union with different decimal precision") { Seq.empty[(Decimal, Decimal)] .toDF("d1", "d2") - .select($"d1".cast(DecimalType(10, 15)).as("d")) + .select($"d1".cast(DecimalType(10, 5)).as("d")) .registerTempTable("dn") sql("select d from dn union all select d * 2 from dn") From dca21a83ac33813dd8165acb5f20d06e4f9b9034 Mon Sep 17 00:00:00 2001 From: fe2s Date: Wed, 24 Jun 2015 15:12:23 -0700 Subject: [PATCH 003/122] [SPARK-8558] [BUILD] Script /dev/run-tests fails when _JAVA_OPTIONS env var set Author: fe2s Author: Oleksiy Dyagilev Closes #6956 from fe2s/fix-run-tests and squashes the following commits: 31b6edc [fe2s] str is a built-in function, so using it as a variable name will lead to spurious warnings in some Python linters 7d781a0 [fe2s] fixing for openjdk/IBM, seems like they have slightly different wording, but all have 'version' word. Surrounding with spaces for the case if version word appears in _JAVA_OPTIONS cd455ef [fe2s] address comment, looking for java version string rather than expecting to have on a certain line number ad577d7 [Oleksiy Dyagilev] [SPARK-8558][BUILD] Script /dev/run-tests fails when _JAVA_OPTIONS env var set --- dev/run-tests.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/dev/run-tests.py b/dev/run-tests.py index de1b4537eda5f..e7c09b0f40cdc 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -477,7 +477,12 @@ def determine_java_version(java_exe): raw_output = subprocess.check_output([java_exe, "-version"], stderr=subprocess.STDOUT) - raw_version_str = raw_output.split('\n')[0] # eg 'java version "1.8.0_25"' + + raw_output_lines = raw_output.split('\n') + + # find raw version string, eg 'java version "1.8.0_25"' + raw_version_str = next(x for x in raw_output_lines if " version " in x) + version_str = raw_version_str.split()[-1].strip('"') # eg '1.8.0_25' version, update = version_str.split('_') # eg ['1.8.0', '25'] From 7daa70292e47be6a944351ef00c770ad4bcb0877 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Wed, 24 Jun 2015 15:52:58 -0700 Subject: [PATCH 004/122] [SPARK-8567] [SQL] Increase the timeout of HiveSparkSubmitSuite https://issues.apache.org/jira/browse/SPARK-8567 Author: Yin Huai Closes #6957 from yhuai/SPARK-8567 and squashes the following commits: 62dff5b [Yin Huai] Increase the timeout. 
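For context, the HiveSparkSubmitSuite change below only widens the ScalaTest failAfter window around a spark-submit child process from 120 to 180 seconds. As a minimal, self-contained sketch of that timeout pattern, written against ScalaTest's Timeouts trait and SpanSugar, with a hypothetical suite name and a placeholder command (the real suite's surrounding imports and helpers are not shown in this hunk):

import org.scalatest.FunSuite
import org.scalatest.concurrent.Timeouts
import org.scalatest.time.SpanSugar._

// Hypothetical example, not part of this patch: it only illustrates how failAfter
// bounds the wait on an external process, which is the window SPARK-8567 relaxes.
class TimeoutPatternSuite extends FunSuite with Timeouts {
  test("external process exits before the time limit") {
    // Placeholder command standing in for the real spark-submit invocation.
    val process = new ProcessBuilder("sleep", "1").start()
    // failAfter fails the test if the block does not complete within the given span.
    val exitCode = failAfter(180.seconds) { process.waitFor() }
    assert(exitCode === 0)
  }
}

The larger window simply gives the child process more time to exit before the test is failed; the rest of the suite is unchanged.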
--- .../scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala index d85516ab0878e..b875e52b986ab 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala @@ -95,7 +95,7 @@ class HiveSparkSubmitSuite )) try { - val exitCode = failAfter(120 seconds) { process.exitValue() } + val exitCode = failAfter(180 seconds) { process.exitValue() } if (exitCode != 0) { fail(s"Process returned with exit code $exitCode. See the log4j logs for more detail.") } From b71d3254e50838ccae43bdb0ff186fda25f03152 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 24 Jun 2015 16:26:00 -0700 Subject: [PATCH 005/122] [SPARK-8075] [SQL] apply type check interface to more expressions a follow up of https://github.com/apache/spark/pull/6405. Note: It's not a big change, a lot of changing is due to I swap some code in `aggregates.scala` to make aggregate functions right below its corresponding aggregate expressions. Author: Wenchen Fan Closes #6723 from cloud-fan/type-check and squashes the following commits: 2124301 [Wenchen Fan] fix tests 5a658bb [Wenchen Fan] add tests 287d3bb [Wenchen Fan] apply type check interface to more expressions --- .../sql/catalyst/analysis/Analyzer.scala | 4 +- .../catalyst/analysis/HiveTypeCoercion.scala | 17 +- .../spark/sql/catalyst/expressions/Cast.scala | 11 +- .../sql/catalyst/expressions/Expression.scala | 4 +- .../catalyst/expressions/ExtractValue.scala | 10 +- .../sql/catalyst/expressions/aggregates.scala | 420 +++++++++--------- .../sql/catalyst/expressions/arithmetic.scala | 2 - .../expressions/complexTypeCreator.scala | 30 +- .../expressions/decimalFunctions.scala | 17 +- .../sql/catalyst/expressions/generators.scala | 13 +- .../spark/sql/catalyst/expressions/math.scala | 4 +- .../expressions/namedExpressions.scala | 4 +- .../catalyst/expressions/nullFunctions.scala | 27 +- .../spark/sql/catalyst/expressions/sets.scala | 10 +- .../expressions/stringOperations.scala | 2 - .../expressions/windowExpressions.scala | 3 +- .../spark/sql/catalyst/util/TypeUtils.scala | 9 + .../sql/catalyst/analysis/AnalysisSuite.scala | 6 +- .../ExpressionTypeCheckingSuite.scala | 26 +- .../spark/sql/execution/pythonUdfs.scala | 2 +- .../execution/HiveTypeCoercionSuite.scala | 6 - 21 files changed, 337 insertions(+), 290 deletions(-) rename sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/{expressions => analysis}/ExpressionTypeCheckingSuite.scala (84%) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index b06759f144fd9..cad2c2abe6b1a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -587,8 +587,8 @@ class Analyzer( failAnalysis( s"""Expect multiple names given for ${g.getClass.getName}, |but only single name '${name}' specified""".stripMargin) - case Alias(g: Generator, name) => Some((g, name :: Nil)) - case MultiAlias(g: Generator, names) => Some(g, names) + case Alias(g: Generator, name) if g.resolved => Some((g, name :: Nil)) + case MultiAlias(g: Generator, names) if g.resolved => Some(g, 
names) case _ => None } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index d4ab1fc643c33..4ef7341a33245 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -317,6 +317,7 @@ trait HiveTypeCoercion { i.makeCopy(Array(Cast(a, StringType), b.map(Cast(_, StringType)))) case Sum(e @ StringType()) => Sum(Cast(e, DoubleType)) + case SumDistinct(e @ StringType()) => Sum(Cast(e, DoubleType)) case Average(e @ StringType()) => Average(Cast(e, DoubleType)) } } @@ -590,11 +591,12 @@ trait HiveTypeCoercion { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e - case a @ CreateArray(children) if !a.resolved => - val commonType = a.childTypes.reduce( - (a, b) => findTightestCommonTypeOfTwo(a, b).getOrElse(StringType)) - CreateArray( - children.map(c => if (c.dataType == commonType) c else Cast(c, commonType))) + case a @ CreateArray(children) if children.map(_.dataType).distinct.size > 1 => + val types = children.map(_.dataType) + findTightestCommonTypeAndPromoteToString(types) match { + case Some(finalDataType) => CreateArray(children.map(Cast(_, finalDataType))) + case None => a + } // Promote SUM, SUM DISTINCT and AVERAGE to largest types to prevent overflows. case s @ Sum(e @ DecimalType()) => s // Decimal is already the biggest. @@ -620,12 +622,11 @@ trait HiveTypeCoercion { // Coalesce should return the first non-null value, which could be any column // from the list. So we need to make sure the return type is deterministic and // compatible with every child column. - case Coalesce(es) if es.map(_.dataType).distinct.size > 1 => + case c @ Coalesce(es) if es.map(_.dataType).distinct.size > 1 => val types = es.map(_.dataType) findTightestCommonTypeAndPromoteToString(types) match { case Some(finalDataType) => Coalesce(es.map(Cast(_, finalDataType))) - case None => - sys.error(s"Could not determine return type of Coalesce for ${types.mkString(",")}") + case None => c } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index d271434a306dd..8bd7fc18a8dd4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -22,7 +22,7 @@ import java.sql.{Date, Timestamp} import java.text.{DateFormat, SimpleDateFormat} import org.apache.spark.Logging -import org.apache.spark.sql.catalyst +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ @@ -31,7 +31,14 @@ import org.apache.spark.unsafe.types.UTF8String /** Cast the child expression to the target data type. 
*/ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression with Logging { - override lazy val resolved = childrenResolved && resolve(child.dataType, dataType) + override def checkInputDataTypes(): TypeCheckResult = { + if (resolve(child.dataType, dataType)) { + TypeCheckResult.TypeCheckSuccess + } else { + TypeCheckResult.TypeCheckFailure( + s"cannot cast ${child.dataType} to $dataType") + } + } override def foldable: Boolean = child.foldable diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index a10a959ae766f..f59db3d5dfc23 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -162,9 +162,7 @@ abstract class Expression extends TreeNode[Expression] { /** * Checks the input data types, returns `TypeCheckResult.success` if it's valid, * or returns a `TypeCheckResult` with an error message if invalid. - * Note: it's not valid to call this method until `childrenResolved == true` - * TODO: we should remove the default implementation and implement it for all - * expressions with proper error message. + * Note: it's not valid to call this method until `childrenResolved == true`. */ def checkInputDataTypes(): TypeCheckResult = TypeCheckResult.TypeCheckSuccess } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala index 4d6c1c265150d..4d7c95ffd1850 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala @@ -96,6 +96,11 @@ object ExtractValue { } } +/** + * A common interface of all kinds of extract value expressions. + * Note: concrete extract value expressions are created only by `ExtractValue.apply`, + * we don't need to do type check for them. 
+ */ trait ExtractValue extends UnaryExpression { self: Product => } @@ -179,9 +184,6 @@ case class GetArrayItem(child: Expression, ordinal: Expression) override def dataType: DataType = child.dataType.asInstanceOf[ArrayType].elementType - override lazy val resolved = childrenResolved && - child.dataType.isInstanceOf[ArrayType] && ordinal.dataType.isInstanceOf[IntegralType] - protected def evalNotNull(value: Any, ordinal: Any) = { // TODO: consider using Array[_] for ArrayType child to avoid // boxing of primitives @@ -203,8 +205,6 @@ case class GetMapValue(child: Expression, ordinal: Expression) override def dataType: DataType = child.dataType.asInstanceOf[MapType].valueType - override lazy val resolved = childrenResolved && child.dataType.isInstanceOf[MapType] - protected def evalNotNull(value: Any, ordinal: Any) = { val baseValue = value.asInstanceOf[Map[Any, _]] baseValue.get(ordinal).orNull diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala index 00d2e499c5890..a9fc54c548f49 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala @@ -19,9 +19,10 @@ package org.apache.spark.sql.catalyst.expressions import com.clearspring.analytics.stream.cardinality.HyperLogLog -import org.apache.spark.sql.catalyst -import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.trees +import org.apache.spark.sql.catalyst.errors.TreeNodeException +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ import org.apache.spark.util.collection.OpenHashSet @@ -101,6 +102,9 @@ case class Min(child: Expression) extends PartialAggregate with trees.UnaryNode[ } override def newInstance(): MinFunction = new MinFunction(child, this) + + override def checkInputDataTypes(): TypeCheckResult = + TypeUtils.checkForOrderingExpr(child.dataType, "function min") } case class MinFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { @@ -132,6 +136,9 @@ case class Max(child: Expression) extends PartialAggregate with trees.UnaryNode[ } override def newInstance(): MaxFunction = new MaxFunction(child, this) + + override def checkInputDataTypes(): TypeCheckResult = + TypeUtils.checkForOrderingExpr(child.dataType, "function max") } case class MaxFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { @@ -165,6 +172,21 @@ case class Count(child: Expression) extends PartialAggregate with trees.UnaryNod override def newInstance(): CountFunction = new CountFunction(child, this) } +case class CountFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { + def this() = this(null, null) // Required for serialization. 
+ + var count: Long = _ + + override def update(input: InternalRow): Unit = { + val evaluatedExpr = expr.eval(input) + if (evaluatedExpr != null) { + count += 1L + } + } + + override def eval(input: InternalRow): Any = count +} + case class CountDistinct(expressions: Seq[Expression]) extends PartialAggregate { def this() = this(null) @@ -183,6 +205,28 @@ case class CountDistinct(expressions: Seq[Expression]) extends PartialAggregate } } +case class CountDistinctFunction( + @transient expr: Seq[Expression], + @transient base: AggregateExpression) + extends AggregateFunction { + + def this() = this(null, null) // Required for serialization. + + val seen = new OpenHashSet[Any]() + + @transient + val distinctValue = new InterpretedProjection(expr) + + override def update(input: InternalRow): Unit = { + val evaluatedExpr = distinctValue(input) + if (!evaluatedExpr.anyNull) { + seen.add(evaluatedExpr) + } + } + + override def eval(input: InternalRow): Any = seen.size.toLong +} + case class CollectHashSet(expressions: Seq[Expression]) extends AggregateExpression { def this() = this(null) @@ -278,6 +322,25 @@ case class ApproxCountDistinctPartition(child: Expression, relativeSD: Double) } } +case class ApproxCountDistinctPartitionFunction( + expr: Expression, + base: AggregateExpression, + relativeSD: Double) + extends AggregateFunction { + def this() = this(null, null, 0) // Required for serialization. + + private val hyperLogLog = new HyperLogLog(relativeSD) + + override def update(input: InternalRow): Unit = { + val evaluatedExpr = expr.eval(input) + if (evaluatedExpr != null) { + hyperLogLog.offer(evaluatedExpr) + } + } + + override def eval(input: InternalRow): Any = hyperLogLog +} + case class ApproxCountDistinctMerge(child: Expression, relativeSD: Double) extends AggregateExpression with trees.UnaryNode[Expression] { @@ -289,6 +352,23 @@ case class ApproxCountDistinctMerge(child: Expression, relativeSD: Double) } } +case class ApproxCountDistinctMergeFunction( + expr: Expression, + base: AggregateExpression, + relativeSD: Double) + extends AggregateFunction { + def this() = this(null, null, 0) // Required for serialization. 
+ + private val hyperLogLog = new HyperLogLog(relativeSD) + + override def update(input: InternalRow): Unit = { + val evaluatedExpr = expr.eval(input) + hyperLogLog.addAll(evaluatedExpr.asInstanceOf[HyperLogLog]) + } + + override def eval(input: InternalRow): Any = hyperLogLog.cardinality() +} + case class ApproxCountDistinct(child: Expression, relativeSD: Double = 0.05) extends PartialAggregate with trees.UnaryNode[Expression] { @@ -349,159 +429,9 @@ case class Average(child: Expression) extends PartialAggregate with trees.UnaryN } override def newInstance(): AverageFunction = new AverageFunction(child, this) -} - -case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { - - override def nullable: Boolean = true - - override def dataType: DataType = child.dataType match { - case DecimalType.Fixed(precision, scale) => - DecimalType(precision + 10, scale) // Add 10 digits left of decimal point, like Hive - case DecimalType.Unlimited => - DecimalType.Unlimited - case _ => - child.dataType - } - - override def toString: String = s"SUM($child)" - - override def asPartial: SplitEvaluation = { - child.dataType match { - case DecimalType.Fixed(_, _) => - val partialSum = Alias(Sum(Cast(child, DecimalType.Unlimited)), "PartialSum")() - SplitEvaluation( - Cast(CombineSum(partialSum.toAttribute), dataType), - partialSum :: Nil) - - case _ => - val partialSum = Alias(Sum(child), "PartialSum")() - SplitEvaluation( - CombineSum(partialSum.toAttribute), - partialSum :: Nil) - } - } - - override def newInstance(): SumFunction = new SumFunction(child, this) -} - -/** - * Sum should satisfy 3 cases: - * 1) sum of all null values = zero - * 2) sum for table column with no data = null - * 3) sum of column with null and not null values = sum of not null values - * Require separate CombineSum Expression and function as it has to distinguish "No data" case - * versus "data equals null" case, while aggregating results and at each partial expression.i.e., - * Combining PartitionLevel InputData - * <-- null - * Zero <-- Zero <-- null - * - * <-- null <-- no data - * null <-- null <-- no data - */ -case class CombineSum(child: Expression) extends AggregateExpression { - def this() = this(null) - - override def children: Seq[Expression] = child :: Nil - override def nullable: Boolean = true - override def dataType: DataType = child.dataType - override def toString: String = s"CombineSum($child)" - override def newInstance(): CombineSumFunction = new CombineSumFunction(child, this) -} - -case class SumDistinct(child: Expression) - extends PartialAggregate with trees.UnaryNode[Expression] { - - def this() = this(null) - override def nullable: Boolean = true - override def dataType: DataType = child.dataType match { - case DecimalType.Fixed(precision, scale) => - DecimalType(precision + 10, scale) // Add 10 digits left of decimal point, like Hive - case DecimalType.Unlimited => - DecimalType.Unlimited - case _ => - child.dataType - } - override def toString: String = s"SUM(DISTINCT $child)" - override def newInstance(): SumDistinctFunction = new SumDistinctFunction(child, this) - - override def asPartial: SplitEvaluation = { - val partialSet = Alias(CollectHashSet(child :: Nil), "partialSets")() - SplitEvaluation( - CombineSetsAndSum(partialSet.toAttribute, this), - partialSet :: Nil) - } -} -case class CombineSetsAndSum(inputSet: Expression, base: Expression) extends AggregateExpression { - def this() = this(null, null) - - override def children: Seq[Expression] = inputSet :: Nil - 
override def nullable: Boolean = true - override def dataType: DataType = base.dataType - override def toString: String = s"CombineAndSum($inputSet)" - override def newInstance(): CombineSetsAndSumFunction = { - new CombineSetsAndSumFunction(inputSet, this) - } -} - -case class CombineSetsAndSumFunction( - @transient inputSet: Expression, - @transient base: AggregateExpression) - extends AggregateFunction { - - def this() = this(null, null) // Required for serialization. - - val seen = new OpenHashSet[Any]() - - override def update(input: InternalRow): Unit = { - val inputSetEval = inputSet.eval(input).asInstanceOf[OpenHashSet[Any]] - val inputIterator = inputSetEval.iterator - while (inputIterator.hasNext) { - seen.add(inputIterator.next) - } - } - - override def eval(input: InternalRow): Any = { - val casted = seen.asInstanceOf[OpenHashSet[InternalRow]] - if (casted.size == 0) { - null - } else { - Cast(Literal( - casted.iterator.map(f => f.apply(0)).reduceLeft( - base.dataType.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].plus)), - base.dataType).eval(null) - } - } -} - -case class First(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { - override def nullable: Boolean = true - override def dataType: DataType = child.dataType - override def toString: String = s"FIRST($child)" - - override def asPartial: SplitEvaluation = { - val partialFirst = Alias(First(child), "PartialFirst")() - SplitEvaluation( - First(partialFirst.toAttribute), - partialFirst :: Nil) - } - override def newInstance(): FirstFunction = new FirstFunction(child, this) -} - -case class Last(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { - override def references: AttributeSet = child.references - override def nullable: Boolean = true - override def dataType: DataType = child.dataType - override def toString: String = s"LAST($child)" - - override def asPartial: SplitEvaluation = { - val partialLast = Alias(Last(child), "PartialLast")() - SplitEvaluation( - Last(partialLast.toAttribute), - partialLast :: Nil) - } - override def newInstance(): LastFunction = new LastFunction(child, this) + override def checkInputDataTypes(): TypeCheckResult = + TypeUtils.checkForNumericExpr(child.dataType, "function average") } case class AverageFunction(expr: Expression, base: AggregateExpression) @@ -551,55 +481,41 @@ case class AverageFunction(expr: Expression, base: AggregateExpression) } } -case class CountFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { - def this() = this(null, null) // Required for serialization. +case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { - var count: Long = _ + override def nullable: Boolean = true - override def update(input: InternalRow): Unit = { - val evaluatedExpr = expr.eval(input) - if (evaluatedExpr != null) { - count += 1L - } + override def dataType: DataType = child.dataType match { + case DecimalType.Fixed(precision, scale) => + DecimalType(precision + 10, scale) // Add 10 digits left of decimal point, like Hive + case DecimalType.Unlimited => + DecimalType.Unlimited + case _ => + child.dataType } - override def eval(input: InternalRow): Any = count -} - -case class ApproxCountDistinctPartitionFunction( - expr: Expression, - base: AggregateExpression, - relativeSD: Double) - extends AggregateFunction { - def this() = this(null, null, 0) // Required for serialization. 
+ override def toString: String = s"SUM($child)" - private val hyperLogLog = new HyperLogLog(relativeSD) + override def asPartial: SplitEvaluation = { + child.dataType match { + case DecimalType.Fixed(_, _) => + val partialSum = Alias(Sum(Cast(child, DecimalType.Unlimited)), "PartialSum")() + SplitEvaluation( + Cast(CombineSum(partialSum.toAttribute), dataType), + partialSum :: Nil) - override def update(input: InternalRow): Unit = { - val evaluatedExpr = expr.eval(input) - if (evaluatedExpr != null) { - hyperLogLog.offer(evaluatedExpr) + case _ => + val partialSum = Alias(Sum(child), "PartialSum")() + SplitEvaluation( + CombineSum(partialSum.toAttribute), + partialSum :: Nil) } } - override def eval(input: InternalRow): Any = hyperLogLog -} - -case class ApproxCountDistinctMergeFunction( - expr: Expression, - base: AggregateExpression, - relativeSD: Double) - extends AggregateFunction { - def this() = this(null, null, 0) // Required for serialization. - - private val hyperLogLog = new HyperLogLog(relativeSD) - - override def update(input: InternalRow): Unit = { - val evaluatedExpr = expr.eval(input) - hyperLogLog.addAll(evaluatedExpr.asInstanceOf[HyperLogLog]) - } + override def newInstance(): SumFunction = new SumFunction(child, this) - override def eval(input: InternalRow): Any = hyperLogLog.cardinality() + override def checkInputDataTypes(): TypeCheckResult = + TypeUtils.checkForNumericExpr(child.dataType, "function sum") } case class SumFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { @@ -633,6 +549,30 @@ case class SumFunction(expr: Expression, base: AggregateExpression) extends Aggr } } +/** + * Sum should satisfy 3 cases: + * 1) sum of all null values = zero + * 2) sum for table column with no data = null + * 3) sum of column with null and not null values = sum of not null values + * Require separate CombineSum Expression and function as it has to distinguish "No data" case + * versus "data equals null" case, while aggregating results and at each partial expression.i.e., + * Combining PartitionLevel InputData + * <-- null + * Zero <-- Zero <-- null + * + * <-- null <-- no data + * null <-- null <-- no data + */ +case class CombineSum(child: Expression) extends AggregateExpression { + def this() = this(null) + + override def children: Seq[Expression] = child :: Nil + override def nullable: Boolean = true + override def dataType: DataType = child.dataType + override def toString: String = s"CombineSum($child)" + override def newInstance(): CombineSumFunction = new CombineSumFunction(child, this) +} + case class CombineSumFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { @@ -670,6 +610,33 @@ case class CombineSumFunction(expr: Expression, base: AggregateExpression) } } +case class SumDistinct(child: Expression) + extends PartialAggregate with trees.UnaryNode[Expression] { + + def this() = this(null) + override def nullable: Boolean = true + override def dataType: DataType = child.dataType match { + case DecimalType.Fixed(precision, scale) => + DecimalType(precision + 10, scale) // Add 10 digits left of decimal point, like Hive + case DecimalType.Unlimited => + DecimalType.Unlimited + case _ => + child.dataType + } + override def toString: String = s"SUM(DISTINCT $child)" + override def newInstance(): SumDistinctFunction = new SumDistinctFunction(child, this) + + override def asPartial: SplitEvaluation = { + val partialSet = Alias(CollectHashSet(child :: Nil), "partialSets")() + SplitEvaluation( + 
CombineSetsAndSum(partialSet.toAttribute, this), + partialSet :: Nil) + } + + override def checkInputDataTypes(): TypeCheckResult = + TypeUtils.checkForNumericExpr(child.dataType, "function sumDistinct") +} + case class SumDistinctFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { @@ -696,8 +663,20 @@ case class SumDistinctFunction(expr: Expression, base: AggregateExpression) } } -case class CountDistinctFunction( - @transient expr: Seq[Expression], +case class CombineSetsAndSum(inputSet: Expression, base: Expression) extends AggregateExpression { + def this() = this(null, null) + + override def children: Seq[Expression] = inputSet :: Nil + override def nullable: Boolean = true + override def dataType: DataType = base.dataType + override def toString: String = s"CombineAndSum($inputSet)" + override def newInstance(): CombineSetsAndSumFunction = { + new CombineSetsAndSumFunction(inputSet, this) + } +} + +case class CombineSetsAndSumFunction( + @transient inputSet: Expression, @transient base: AggregateExpression) extends AggregateFunction { @@ -705,17 +684,39 @@ case class CountDistinctFunction( val seen = new OpenHashSet[Any]() - @transient - val distinctValue = new InterpretedProjection(expr) - override def update(input: InternalRow): Unit = { - val evaluatedExpr = distinctValue(input) - if (!evaluatedExpr.anyNull) { - seen.add(evaluatedExpr) + val inputSetEval = inputSet.eval(input).asInstanceOf[OpenHashSet[Any]] + val inputIterator = inputSetEval.iterator + while (inputIterator.hasNext) { + seen.add(inputIterator.next) } } - override def eval(input: InternalRow): Any = seen.size.toLong + override def eval(input: InternalRow): Any = { + val casted = seen.asInstanceOf[OpenHashSet[InternalRow]] + if (casted.size == 0) { + null + } else { + Cast(Literal( + casted.iterator.map(f => f.apply(0)).reduceLeft( + base.dataType.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].plus)), + base.dataType).eval(null) + } + } +} + +case class First(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { + override def nullable: Boolean = true + override def dataType: DataType = child.dataType + override def toString: String = s"FIRST($child)" + + override def asPartial: SplitEvaluation = { + val partialFirst = Alias(First(child), "PartialFirst")() + SplitEvaluation( + First(partialFirst.toAttribute), + partialFirst :: Nil) + } + override def newInstance(): FirstFunction = new FirstFunction(child, this) } case class FirstFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { @@ -732,6 +733,21 @@ case class FirstFunction(expr: Expression, base: AggregateExpression) extends Ag override def eval(input: InternalRow): Any = result } +case class Last(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { + override def references: AttributeSet = child.references + override def nullable: Boolean = true + override def dataType: DataType = child.dataType + override def toString: String = s"LAST($child)" + + override def asPartial: SplitEvaluation = { + val partialLast = Alias(Last(child), "PartialLast")() + SplitEvaluation( + Last(partialLast.toAttribute), + partialLast :: Nil) + } + override def newInstance(): LastFunction = new LastFunction(child, this) +} + case class LastFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { def this() = this(null, null) // Required for serialization. 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index ace8427c8ddaf..3d4d9e2d798f0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -25,8 +25,6 @@ import org.apache.spark.sql.types._ abstract class UnaryArithmetic extends UnaryExpression { self: Product => - override def foldable: Boolean = child.foldable - override def nullable: Boolean = child.nullable override def dataType: DataType = child.dataType override def eval(input: InternalRow): Any = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index e0bf07ed182f3..5def57b067424 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -17,9 +17,10 @@ package org.apache.spark.sql.catalyst.expressions +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ - /** * Returns an Array containing the evaluation of all children expressions. */ @@ -27,15 +28,12 @@ case class CreateArray(children: Seq[Expression]) extends Expression { override def foldable: Boolean = children.forall(_.foldable) - lazy val childTypes = children.map(_.dataType).distinct - - override lazy val resolved = - childrenResolved && childTypes.size <= 1 + override def checkInputDataTypes(): TypeCheckResult = + TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), "function array") override def dataType: DataType = { - assert(resolved, s"Invalid dataType of mixed ArrayType ${childTypes.mkString(",")}") ArrayType( - childTypes.headOption.getOrElse(NullType), + children.headOption.map(_.dataType).getOrElse(NullType), containsNull = children.exists(_.nullable)) } @@ -56,19 +54,15 @@ case class CreateStruct(children: Seq[Expression]) extends Expression { override def foldable: Boolean = children.forall(_.foldable) - override lazy val resolved: Boolean = childrenResolved - override lazy val dataType: StructType = { - assert(resolved, - s"CreateStruct contains unresolvable children: ${children.filterNot(_.resolved)}.") - val fields = children.zipWithIndex.map { case (child, idx) => - child match { - case ne: NamedExpression => - StructField(ne.name, ne.dataType, ne.nullable, ne.metadata) - case _ => - StructField(s"col${idx + 1}", child.dataType, child.nullable, Metadata.empty) - } + val fields = children.zipWithIndex.map { case (child, idx) => + child match { + case ne: NamedExpression => + StructField(ne.name, ne.dataType, ne.nullable, ne.metadata) + case _ => + StructField(s"col${idx + 1}", child.dataType, child.nullable, Metadata.empty) } + } StructType(fields) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala index 2bc893af02641..f5c2dde191cf3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala @@ -17,16 +17,17 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} import org.apache.spark.sql.types._ -/** Return the unscaled Long value of a Decimal, assuming it fits in a Long */ +/** + * Return the unscaled Long value of a Decimal, assuming it fits in a Long. + * Note: this expression is internal and created only by the optimizer, + * we don't need to do type check for it. + */ case class UnscaledValue(child: Expression) extends UnaryExpression { override def dataType: DataType = LongType - override def foldable: Boolean = child.foldable - override def nullable: Boolean = child.nullable override def toString: String = s"UnscaledValue($child)" override def eval(input: InternalRow): Any = { @@ -43,12 +44,14 @@ case class UnscaledValue(child: Expression) extends UnaryExpression { } } -/** Create a Decimal from an unscaled Long value */ +/** + * Create a Decimal from an unscaled Long value. + * Note: this expression is internal and created only by the optimizer, + * we don't need to do type check for it. + */ case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends UnaryExpression { override def dataType: DataType = DecimalType(precision, scale) - override def foldable: Boolean = child.foldable - override def nullable: Boolean = child.nullable override def toString: String = s"MakeDecimal($child,$precision,$scale)" override def eval(input: InternalRow): Decimal = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index f30cb42d12b83..356560e54cae3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import scala.collection.Map -import org.apache.spark.sql.catalyst +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.{CatalystTypeConverters, trees} import org.apache.spark.sql.types._ @@ -100,9 +100,14 @@ case class UserDefinedGenerator( case class Explode(child: Expression) extends Generator with trees.UnaryNode[Expression] { - override lazy val resolved = - child.resolved && - (child.dataType.isInstanceOf[ArrayType] || child.dataType.isInstanceOf[MapType]) + override def checkInputDataTypes(): TypeCheckResult = { + if (child.dataType.isInstanceOf[ArrayType] || child.dataType.isInstanceOf[MapType]) { + TypeCheckResult.TypeCheckSuccess + } else { + TypeCheckResult.TypeCheckFailure( + s"input to function explode should be array or map type, not ${child.dataType}") + } + } override def elementTypes: Seq[(DataType, Boolean)] = child.dataType match { case ArrayType(et, containsNull) => (et, containsNull) :: Nil diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index 250564dc4b818..5694afc61be05 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -19,7 +19,6 @@ package 
org.apache.spark.sql.catalyst.expressions import java.lang.{Long => JLong} -import org.apache.spark.sql.catalyst import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.types.{DataType, DoubleType, LongType, StringType} import org.apache.spark.unsafe.types.UTF8String @@ -60,7 +59,6 @@ abstract class UnaryMathExpression(f: Double => Double, name: String) override def expectedChildTypes: Seq[DataType] = Seq(DoubleType) override def dataType: DataType = DoubleType - override def foldable: Boolean = child.foldable override def nullable: Boolean = true override def toString: String = s"$name($child)" @@ -224,7 +222,7 @@ case class Bin(child: Expression) def funcName: String = name.toLowerCase - override def eval(input: catalyst.InternalRow): Any = { + override def eval(input: InternalRow): Any = { val evalE = child.eval(input) if (evalE == null) { null diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 9cacdceb13837..6f56a9ec7beb5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} @@ -113,7 +112,8 @@ case class Alias(child: Expression, name: String)( extends NamedExpression with trees.UnaryNode[Expression] { // Alias(Generator, xx) need to be transformed into Generate(generator, ...) - override lazy val resolved = childrenResolved && !child.isInstanceOf[Generator] + override lazy val resolved = + childrenResolved && checkInputDataTypes().isSuccess && !child.isInstanceOf[Generator] override def eval(input: InternalRow): Any = child.eval(input) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala index 98acaf23c44c1..5d5911403ece1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala @@ -17,33 +17,32 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.analysis.UnresolvedException +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} +import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types.DataType case class Coalesce(children: Seq[Expression]) extends Expression { /** Coalesce is nullable if all of its children are nullable, or if it has no children. */ - override def nullable: Boolean = !children.exists(!_.nullable) + override def nullable: Boolean = children.forall(_.nullable) // Coalesce is foldable if all children are foldable. - override def foldable: Boolean = !children.exists(!_.foldable) + override def foldable: Boolean = children.forall(_.foldable) - // Only resolved if all the children are of the same type. 
- override lazy val resolved = childrenResolved && (children.map(_.dataType).distinct.size == 1) + override def checkInputDataTypes(): TypeCheckResult = { + if (children == Nil) { + TypeCheckResult.TypeCheckFailure("input to function coalesce cannot be empty") + } else { + TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), "function coalesce") + } + } override def toString: String = s"Coalesce(${children.mkString(",")})" - override def dataType: DataType = if (resolved) { - children.head.dataType - } else { - val childTypes = children.map(c => s"$c: ${c.dataType}").mkString(", ") - throw new UnresolvedException( - this, s"Coalesce cannot have children of different types. $childTypes") - } + override def dataType: DataType = children.head.dataType override def eval(input: InternalRow): Any = { - var i = 0 var result: Any = null val childIterator = children.iterator while (childIterator.hasNext && result == null) { @@ -75,7 +74,6 @@ case class Coalesce(children: Seq[Expression]) extends Expression { } case class IsNull(child: Expression) extends UnaryExpression with Predicate { - override def foldable: Boolean = child.foldable override def nullable: Boolean = false override def eval(input: InternalRow): Any = { @@ -93,7 +91,6 @@ case class IsNull(child: Expression) extends UnaryExpression with Predicate { } case class IsNotNull(child: Expression) extends UnaryExpression with Predicate { - override def foldable: Boolean = child.foldable override def nullable: Boolean = false override def toString: String = s"IS NOT NULL $child" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala index 30e41677b774b..efc6f50b78943 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala @@ -78,6 +78,8 @@ case class NewSet(elementType: DataType) extends LeafExpression { /** * Adds an item to a set. * For performance, this expression mutates its input during evaluation. + * Note: this expression is internal and created only by the GeneratedAggregate, + * we don't need to do type check for it. */ case class AddItemToSet(item: Expression, set: Expression) extends Expression { @@ -85,7 +87,7 @@ case class AddItemToSet(item: Expression, set: Expression) extends Expression { override def nullable: Boolean = set.nullable - override def dataType: OpenHashSetUDT = set.dataType.asInstanceOf[OpenHashSetUDT] + override def dataType: DataType = set.dataType override def eval(input: InternalRow): Any = { val itemEval = item.eval(input) @@ -128,12 +130,14 @@ case class AddItemToSet(item: Expression, set: Expression) extends Expression { /** * Combines the elements of two sets. * For performance, this expression mutates its left input set during evaluation. + * Note: this expression is internal and created only by the GeneratedAggregate, + * we don't need to do type check for it. */ case class CombineSets(left: Expression, right: Expression) extends BinaryExpression { override def nullable: Boolean = left.nullable || right.nullable - override def dataType: OpenHashSetUDT = left.dataType.asInstanceOf[OpenHashSetUDT] + override def dataType: DataType = left.dataType override def symbol: String = "++=" @@ -176,6 +180,8 @@ case class CombineSets(left: Expression, right: Expression) extends BinaryExpres /** * Returns the number of elements in the input set. 
+ * Note: this expression is internal and created only by the GeneratedAggregate, + * we don't need to do type check for it. */ case class CountSet(child: Expression) extends UnaryExpression { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 315c63e63c635..44416e79cd7aa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -117,8 +117,6 @@ trait CaseConversionExpression extends ExpectsInputTypes { def convert(v: UTF8String): UTF8String - override def foldable: Boolean = child.foldable - override def nullable: Boolean = child.nullable override def dataType: DataType = StringType override def expectedChildTypes: Seq[DataType] = Seq(StringType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala index 896e383f50eac..12023ad311dc8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala @@ -68,7 +68,8 @@ case class WindowSpecDefinition( override def children: Seq[Expression] = partitionSpec ++ orderSpec override lazy val resolved: Boolean = - childrenResolved && frameSpecification.isInstanceOf[SpecifiedWindowFrame] + childrenResolved && checkInputDataTypes().isSuccess && + frameSpecification.isInstanceOf[SpecifiedWindowFrame] override def toString: String = simpleString diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala index 04857a23f4c1e..8656cc334d09f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala @@ -48,6 +48,15 @@ object TypeUtils { } } + def checkForSameTypeInputExpr(types: Seq[DataType], caller: String): TypeCheckResult = { + if (types.distinct.size > 1) { + TypeCheckResult.TypeCheckFailure( + s"input to $caller should all be the same type, but it's ${types.mkString("[", ", ", "]")}") + } else { + TypeCheckResult.TypeCheckSuccess + } + } + def getNumeric(t: DataType): Numeric[Any] = t.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]] diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index e09cd790a7187..77ca080f366cd 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -193,7 +193,7 @@ class AnalysisSuite extends SparkFunSuite with BeforeAndAfter { errorTest( "bad casts", testRelation.select(Literal(1).cast(BinaryType).as('badCast)), - "invalid cast" :: Literal(1).dataType.simpleString :: BinaryType.simpleString :: Nil) + "cannot cast" :: Literal(1).dataType.simpleString :: BinaryType.simpleString :: Nil) errorTest( "non-boolean filters", @@ -264,9 +264,9 @@ class AnalysisSuite extends SparkFunSuite with BeforeAndAfter { val plan = Aggregate( Nil, - 
Alias(Sum(AttributeReference("a", StringType)(exprId = ExprId(1))), "b")() :: Nil, + Alias(Sum(AttributeReference("a", IntegerType)(exprId = ExprId(1))), "b")() :: Nil, LocalRelation( - AttributeReference("a", StringType)(exprId = ExprId(2)))) + AttributeReference("a", IntegerType)(exprId = ExprId(2)))) assert(plan.resolved) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala similarity index 84% rename from sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionTypeCheckingSuite.scala rename to sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala index 49b111989799b..bc1537b0715b5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala @@ -15,13 +15,13 @@ * limitations under the License. */ -package org.apache.spark.sql.catalyst.expressions +package org.apache.spark.sql.catalyst.analysis import org.apache.spark.SparkFunSuite import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.types.StringType @@ -136,6 +136,28 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { assertError( CaseWhen(Seq('booleanField, 'intField, 'intField, 'intField)), "WHEN expressions in CaseWhen should all be boolean type") + } + + test("check types for aggregates") { + // We will cast String to Double for sum and average + assertSuccess(Sum('stringField)) + assertSuccess(SumDistinct('stringField)) + assertSuccess(Average('stringField)) + + assertError(Min('complexField), "function min accepts non-complex type") + assertError(Max('complexField), "function max accepts non-complex type") + assertError(Sum('booleanField), "function sum accepts numeric type") + assertError(SumDistinct('booleanField), "function sumDistinct accepts numeric type") + assertError(Average('booleanField), "function average accepts numeric type") + } + test("check types for others") { + assertError(CreateArray(Seq('intField, 'booleanField)), + "input to function array should all be the same type") + assertError(Coalesce(Seq('intField, 'booleanField)), + "input to function coalesce should all be the same type") + assertError(Coalesce(Nil), "input to function coalesce cannot be empty") + assertError(Explode('intField), + "input to function explode should be array or map type") } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala index 6db551c543a9c..f9c3fe92c2670 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala @@ -55,7 +55,7 @@ private[spark] case class PythonUDF( override def toString: String = s"PythonUDF#$name(${children.mkString(",")})" - def nullable: Boolean = true + override def nullable: Boolean = true override def eval(input: InternalRow): Any = { throw new UnsupportedOperationException("PythonUDFs can not be 
directly evaluated.") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala index f0f04f8c73fb4..197e9bfb02c4e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala @@ -59,10 +59,4 @@ class HiveTypeCoercionSuite extends HiveComparisonTest { } assert(numEquals === 1) } - - test("COALESCE with different types") { - intercept[RuntimeException] { - TestHive.sql("""SELECT COALESCE(1, true, "abc") FROM src limit 1""").collect() - } - } } From 82f80c1c7dc42c11bca2b6832c10f9610a43391b Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 24 Jun 2015 19:34:07 -0700 Subject: [PATCH 006/122] Two minor SQL cleanup (compiler warning & indent). Author: Reynold Xin Closes #7000 from rxin/minor-cleanup and squashes the following commits: 046044c [Reynold Xin] Two minor SQL cleanup (compiler warning & indent). --- .../org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 4 ++-- .../apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index cad2c2abe6b1a..117c87a785fdb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -309,8 +309,8 @@ class Analyzer( .nonEmpty => (oldVersion, oldVersion.copy(windowExpressions = newAliases(windowExpressions))) } - // Only handle first case, others will be fixed on the next pass. - .headOption match { + // Only handle first case, others will be fixed on the next pass. + .headOption match { case None => /* * No result implies that there is a logical plan node that produces new references diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 4ef7341a33245..976fa57cb98d5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -678,8 +678,8 @@ trait HiveTypeCoercion { findTightestCommonTypeAndPromoteToString((c.key +: c.whenList).map(_.dataType)) maybeCommonType.map { commonType => val castedBranches = c.branches.grouped(2).map { - case Seq(when, then) if when.dataType != commonType => - Seq(Cast(when, commonType), then) + case Seq(whenExpr, thenExpr) if whenExpr.dataType != commonType => + Seq(Cast(whenExpr, commonType), thenExpr) case other => other }.reduce(_ ++ _) CaseKeyWhen(Cast(c.key, commonType), castedBranches) From 7bac2fe7717c0102b4875dbd95ae0bbf964536e3 Mon Sep 17 00:00:00 2001 From: Matt Massie Date: Wed, 24 Jun 2015 22:09:31 -0700 Subject: [PATCH 007/122] [SPARK-7884] Move block deserialization from BlockStoreShuffleFetcher to ShuffleReader This commit updates the shuffle read path to enable ShuffleReader implementations more control over the deserialization process. The BlockStoreShuffleFetcher.fetch() method has been renamed to BlockStoreShuffleFetcher.fetchBlockStreams(). 
Previously, this method returned a record iterator; now, it returns an iterator of (BlockId, InputStream). Deserialization of records is now handled in the ShuffleReader.read() method. This change creates a cleaner separation of concerns and allows implementations of ShuffleReader more flexibility in how records are retrieved. Author: Matt Massie Author: Kay Ousterhout Closes #6423 from massie/shuffle-api-cleanup and squashes the following commits: 8b0632c [Matt Massie] Minor Scala style fixes d0a1b39 [Matt Massie] Merge pull request #1 from kayousterhout/massie_shuffle-api-cleanup 290f1eb [Kay Ousterhout] Added test for HashShuffleReader.read() 5186da0 [Kay Ousterhout] Revert "Add test to ensure HashShuffleReader is freeing resources" f98a1b9 [Matt Massie] Add test to ensure HashShuffleReader is freeing resources a011bfa [Matt Massie] Use PrivateMethodTester on check that delegate stream is closed 4ea1712 [Matt Massie] Small code cleanup for readability 7429a98 [Matt Massie] Update tests to check that BufferReleasingStream is closing delegate InputStream f458489 [Matt Massie] Remove unnecessary map() on return Iterator 4abb855 [Matt Massie] Consolidate metric code. Make it clear why InterrubtibleIterator is needed. 5c30405 [Matt Massie] Return visibility of BlockStoreShuffleFetcher to private[hash] 7eedd1d [Matt Massie] Small Scala import cleanup 28f8085 [Matt Massie] Small import nit f93841e [Matt Massie] Update shuffle read metrics in ShuffleReader instead of BlockStoreShuffleFetcher. 7e8e0fe [Matt Massie] Minor Scala style fixes 01e8721 [Matt Massie] Explicitly cast iterator in branches for type clarity 7c8f73e [Matt Massie] Close Block InputStream immediately after all records are read 208b7a5 [Matt Massie] Small code style changes b70c945 [Matt Massie] Make BlockStoreShuffleFetcher visible to shuffle package 19135f2 [Matt Massie] [SPARK-7884] Allow Spark shuffle APIs to be more customizable --- .../hash/BlockStoreShuffleFetcher.scala | 59 +++---- .../shuffle/hash/HashShuffleReader.scala | 52 +++++- .../storage/ShuffleBlockFetcherIterator.scala | 90 +++++++---- .../shuffle/hash/HashShuffleReaderSuite.scala | 150 ++++++++++++++++++ .../ShuffleBlockFetcherIteratorSuite.scala | 59 ++++--- 5 files changed, 314 insertions(+), 96 deletions(-) create mode 100644 core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala index 597d46a3d2223..9d8e7e9f03aea 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala @@ -17,29 +17,29 @@ package org.apache.spark.shuffle.hash -import scala.collection.mutable.ArrayBuffer -import scala.collection.mutable.HashMap -import scala.util.{Failure, Success, Try} +import java.io.InputStream + +import scala.collection.mutable.{ArrayBuffer, HashMap} +import scala.util.{Failure, Success} import org.apache.spark._ -import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.FetchFailedException -import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockFetcherIterator, ShuffleBlockId} -import org.apache.spark.util.CompletionIterator +import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerId, ShuffleBlockFetcherIterator, + ShuffleBlockId} private[hash] object BlockStoreShuffleFetcher extends Logging { - def 
fetch[T]( + def fetchBlockStreams( shuffleId: Int, reduceId: Int, context: TaskContext, - serializer: Serializer) - : Iterator[T] = + blockManager: BlockManager, + mapOutputTracker: MapOutputTracker) + : Iterator[(BlockId, InputStream)] = { logDebug("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId)) - val blockManager = SparkEnv.get.blockManager val startTime = System.currentTimeMillis - val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, reduceId) + val statuses = mapOutputTracker.getServerStatuses(shuffleId, reduceId) logDebug("Fetching map output location for shuffle %d, reduce %d took %d ms".format( shuffleId, reduceId, System.currentTimeMillis - startTime)) @@ -53,12 +53,21 @@ private[hash] object BlockStoreShuffleFetcher extends Logging { (address, splits.map(s => (ShuffleBlockId(shuffleId, s._1, reduceId), s._2))) } - def unpackBlock(blockPair: (BlockId, Try[Iterator[Any]])) : Iterator[T] = { + val blockFetcherItr = new ShuffleBlockFetcherIterator( + context, + blockManager.shuffleClient, + blockManager, + blocksByAddress, + // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility + SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024) + + // Make sure that fetch failures are wrapped inside a FetchFailedException for the scheduler + blockFetcherItr.map { blockPair => val blockId = blockPair._1 val blockOption = blockPair._2 blockOption match { - case Success(block) => { - block.asInstanceOf[Iterator[T]] + case Success(inputStream) => { + (blockId, inputStream) } case Failure(e) => { blockId match { @@ -72,27 +81,5 @@ private[hash] object BlockStoreShuffleFetcher extends Logging { } } } - - val blockFetcherItr = new ShuffleBlockFetcherIterator( - context, - SparkEnv.get.blockManager.shuffleClient, - blockManager, - blocksByAddress, - serializer, - // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility - SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024) - val itr = blockFetcherItr.flatMap(unpackBlock) - - val completionIter = CompletionIterator[T, Iterator[T]](itr, { - context.taskMetrics.updateShuffleReadMetrics() - }) - - new InterruptibleIterator[T](context, completionIter) { - val readMetrics = context.taskMetrics.createShuffleReadMetricsForDependency() - override def next(): T = { - readMetrics.incRecordsRead(1) - delegate.next() - } - } } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala index 41bafabde05b9..d5c9880659dd3 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala @@ -17,16 +17,20 @@ package org.apache.spark.shuffle.hash -import org.apache.spark.{InterruptibleIterator, TaskContext} +import org.apache.spark.{InterruptibleIterator, MapOutputTracker, SparkEnv, TaskContext} import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader} +import org.apache.spark.storage.BlockManager +import org.apache.spark.util.CompletionIterator import org.apache.spark.util.collection.ExternalSorter private[spark] class HashShuffleReader[K, C]( handle: BaseShuffleHandle[K, _, C], startPartition: Int, endPartition: Int, - context: TaskContext) + context: TaskContext, + blockManager: BlockManager = SparkEnv.get.blockManager, + mapOutputTracker: 
MapOutputTracker = SparkEnv.get.mapOutputTracker) extends ShuffleReader[K, C] { require(endPartition == startPartition + 1, @@ -36,20 +40,52 @@ private[spark] class HashShuffleReader[K, C]( /** Read the combined key-values for this reduce task */ override def read(): Iterator[Product2[K, C]] = { + val blockStreams = BlockStoreShuffleFetcher.fetchBlockStreams( + handle.shuffleId, startPartition, context, blockManager, mapOutputTracker) + + // Wrap the streams for compression based on configuration + val wrappedStreams = blockStreams.map { case (blockId, inputStream) => + blockManager.wrapForCompression(blockId, inputStream) + } + val ser = Serializer.getSerializer(dep.serializer) - val iter = BlockStoreShuffleFetcher.fetch(handle.shuffleId, startPartition, context, ser) + val serializerInstance = ser.newInstance() + + // Create a key/value iterator for each stream + val recordIter = wrappedStreams.flatMap { wrappedStream => + // Note: the asKeyValueIterator below wraps a key/value iterator inside of a + // NextIterator. The NextIterator makes sure that close() is called on the + // underlying InputStream when all records have been read. + serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator + } + + // Update the context task metrics for each record read. + val readMetrics = context.taskMetrics.createShuffleReadMetricsForDependency() + val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]]( + recordIter.map(record => { + readMetrics.incRecordsRead(1) + record + }), + context.taskMetrics().updateShuffleReadMetrics()) + + // An interruptible iterator must be used here in order to support task cancellation + val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter) val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) { if (dep.mapSideCombine) { - new InterruptibleIterator(context, dep.aggregator.get.combineCombinersByKey(iter, context)) + // We are reading values that are already combined + val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]] + dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context) } else { - new InterruptibleIterator(context, dep.aggregator.get.combineValuesByKey(iter, context)) + // We don't know the value type, but also don't care -- the dependency *should* + // have made sure its compatible w/ this aggregator, which will convert the value + // type to the combined type C + val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]] + dep.aggregator.get.combineValuesByKey(keyValuesIterator, context) } } else { require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!") - - // Convert the Product2s to pairs since this is what downstream RDDs currently expect - iter.asInstanceOf[Iterator[Product2[K, C]]].map(pair => (pair._1, pair._2)) + interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]] } // Sort the output if there is a sort ordering defined. 
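
The refactor shown above follows a general pattern: the fetcher yields raw `(BlockId, InputStream)` pairs, and the reader deserializes each stream and closes it as soon as its records are consumed. The following self-contained sketch illustrates that pattern in plain Scala; the length-prefixed `Int` framing, the `makeBlock` helper, and the block-id strings are invented for illustration and are not Spark APIs.

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream, InputStream}

// Frame a toy "shuffle block" as a record count followed by that many Ints.
def makeBlock(values: Seq[Int]): InputStream = {
  val bytes = new ByteArrayOutputStream()
  val out = new DataOutputStream(bytes)
  out.writeInt(values.length)
  values.foreach(out.writeInt)
  out.flush()
  new ByteArrayInputStream(bytes.toByteArray)
}

// Reader side: flatten (blockId, stream) pairs into one record iterator,
// closing each stream once its records have been read. The real read()
// achieves this lazily via NextIterator; here records are read eagerly
// to keep the sketch short.
def readRecords(blocks: Iterator[(String, InputStream)]): Iterator[Int] =
  blocks.flatMap { case (_, stream) =>
    val in = new DataInputStream(stream)
    val records = Seq.fill(in.readInt())(in.readInt())
    stream.close()
    records.iterator
  }

val fetched = Iterator(
  "block_0" -> makeBlock(Seq(1, 2, 3)),
  "block_1" -> makeBlock(Seq(4, 5)))
assert(readRecords(fetched).toList == List(1, 2, 3, 4, 5))
```
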
diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index d0faab62c9e9e..e49e39679e940 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -17,23 +17,23 @@ package org.apache.spark.storage +import java.io.InputStream import java.util.concurrent.LinkedBlockingQueue import scala.collection.mutable.{ArrayBuffer, HashSet, Queue} import scala.util.{Failure, Try} import org.apache.spark.{Logging, TaskContext} -import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient} import org.apache.spark.network.buffer.ManagedBuffer -import org.apache.spark.serializer.{SerializerInstance, Serializer} -import org.apache.spark.util.{CompletionIterator, Utils} +import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient} +import org.apache.spark.util.Utils /** * An iterator that fetches multiple blocks. For local blocks, it fetches from the local block * manager. For remote blocks, it fetches them using the provided BlockTransferService. * - * This creates an iterator of (BlockID, values) tuples so the caller can handle blocks in a - * pipelined fashion as they are received. + * This creates an iterator of (BlockID, Try[InputStream]) tuples so the caller can handle blocks + * in a pipelined fashion as they are received. * * The implementation throttles the remote fetches to they don't exceed maxBytesInFlight to avoid * using too much memory. @@ -44,7 +44,6 @@ import org.apache.spark.util.{CompletionIterator, Utils} * @param blocksByAddress list of blocks to fetch grouped by the [[BlockManagerId]]. * For each block we also require the size (in bytes as a long field) in * order to throttle the memory usage. - * @param serializer serializer used to deserialize the data. * @param maxBytesInFlight max size (in bytes) of remote blocks to fetch at any given point. */ private[spark] @@ -53,9 +52,8 @@ final class ShuffleBlockFetcherIterator( shuffleClient: ShuffleClient, blockManager: BlockManager, blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])], - serializer: Serializer, maxBytesInFlight: Long) - extends Iterator[(BlockId, Try[Iterator[Any]])] with Logging { + extends Iterator[(BlockId, Try[InputStream])] with Logging { import ShuffleBlockFetcherIterator._ @@ -83,7 +81,7 @@ final class ShuffleBlockFetcherIterator( /** * A queue to hold our results. This turns the asynchronous model provided by - * [[BlockTransferService]] into a synchronous model (iterator). + * [[org.apache.spark.network.BlockTransferService]] into a synchronous model (iterator). */ private[this] val results = new LinkedBlockingQueue[FetchResult] @@ -102,9 +100,7 @@ final class ShuffleBlockFetcherIterator( /** Current bytes in flight from our requests */ private[this] var bytesInFlight = 0L - private[this] val shuffleMetrics = context.taskMetrics.createShuffleReadMetricsForDependency() - - private[this] val serializerInstance: SerializerInstance = serializer.newInstance() + private[this] val shuffleMetrics = context.taskMetrics().createShuffleReadMetricsForDependency() /** * Whether the iterator is still active. If isZombie is true, the callback interface will no @@ -114,17 +110,23 @@ final class ShuffleBlockFetcherIterator( initialize() - /** - * Mark the iterator as zombie, and release all buffers that haven't been deserialized yet. 
- */ - private[this] def cleanup() { - isZombie = true + // Decrements the buffer reference count. + // The currentResult is set to null to prevent releasing the buffer again on cleanup() + private[storage] def releaseCurrentResultBuffer(): Unit = { // Release the current buffer if necessary currentResult match { case SuccessFetchResult(_, _, buf) => buf.release() case _ => } + currentResult = null + } + /** + * Mark the iterator as zombie, and release all buffers that haven't been deserialized yet. + */ + private[this] def cleanup() { + isZombie = true + releaseCurrentResultBuffer() // Release buffers in the results queue val iter = results.iterator() while (iter.hasNext) { @@ -272,7 +274,13 @@ final class ShuffleBlockFetcherIterator( override def hasNext: Boolean = numBlocksProcessed < numBlocksToFetch - override def next(): (BlockId, Try[Iterator[Any]]) = { + /** + * Fetches the next (BlockId, Try[InputStream]). If a task fails, the ManagedBuffers + * underlying each InputStream will be freed by the cleanup() method registered with the + * TaskCompletionListener. However, callers should close() these InputStreams + * as soon as they are no longer needed, in order to release memory as early as possible. + */ + override def next(): (BlockId, Try[InputStream]) = { numBlocksProcessed += 1 val startFetchWait = System.currentTimeMillis() currentResult = results.take() @@ -290,22 +298,15 @@ final class ShuffleBlockFetcherIterator( sendRequest(fetchRequests.dequeue()) } - val iteratorTry: Try[Iterator[Any]] = result match { + val iteratorTry: Try[InputStream] = result match { case FailureFetchResult(_, e) => Failure(e) case SuccessFetchResult(blockId, _, buf) => // There is a chance that createInputStream can fail (e.g. fetching a local file that does // not exist, SPARK-4085). In that case, we should propagate the right exception so // the scheduler gets a FetchFailedException. - Try(buf.createInputStream()).map { is0 => - val is = blockManager.wrapForCompression(blockId, is0) - val iter = serializerInstance.deserializeStream(is).asKeyValueIterator - CompletionIterator[Any, Iterator[Any]](iter, { - // Once the iterator is exhausted, release the buffer and set currentResult to null - // so we don't release it again in cleanup. 
- currentResult = null - buf.release() - }) + Try(buf.createInputStream()).map { inputStream => + new BufferReleasingInputStream(inputStream, this) } } @@ -313,6 +314,39 @@ final class ShuffleBlockFetcherIterator( } } +/** + * Helper class that ensures a ManagedBuffer is release upon InputStream.close() + */ +private class BufferReleasingInputStream( + private val delegate: InputStream, + private val iterator: ShuffleBlockFetcherIterator) + extends InputStream { + private[this] var closed = false + + override def read(): Int = delegate.read() + + override def close(): Unit = { + if (!closed) { + delegate.close() + iterator.releaseCurrentResultBuffer() + closed = true + } + } + + override def available(): Int = delegate.available() + + override def mark(readlimit: Int): Unit = delegate.mark(readlimit) + + override def skip(n: Long): Long = delegate.skip(n) + + override def markSupported(): Boolean = delegate.markSupported() + + override def read(b: Array[Byte]): Int = delegate.read(b) + + override def read(b: Array[Byte], off: Int, len: Int): Int = delegate.read(b, off, len) + + override def reset(): Unit = delegate.reset() +} private[storage] object ShuffleBlockFetcherIterator { diff --git a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala new file mode 100644 index 0000000000000..28ca68698e3dc --- /dev/null +++ b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala @@ -0,0 +1,150 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.hash + +import java.io.{ByteArrayOutputStream, InputStream} +import java.nio.ByteBuffer + +import org.mockito.Matchers.{eq => meq, _} +import org.mockito.Mockito.{mock, when} +import org.mockito.invocation.InvocationOnMock +import org.mockito.stubbing.Answer + +import org.apache.spark._ +import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} +import org.apache.spark.serializer.JavaSerializer +import org.apache.spark.shuffle.BaseShuffleHandle +import org.apache.spark.storage.{BlockManager, BlockManagerId, ShuffleBlockId} + +/** + * Wrapper for a managed buffer that keeps track of how many times retain and release are called. + * + * We need to define this class ourselves instead of using a spy because the NioManagedBuffer class + * is final (final classes cannot be spied on). 
+ */ +class RecordingManagedBuffer(underlyingBuffer: NioManagedBuffer) extends ManagedBuffer { + var callsToRetain = 0 + var callsToRelease = 0 + + override def size(): Long = underlyingBuffer.size() + override def nioByteBuffer(): ByteBuffer = underlyingBuffer.nioByteBuffer() + override def createInputStream(): InputStream = underlyingBuffer.createInputStream() + override def convertToNetty(): AnyRef = underlyingBuffer.convertToNetty() + + override def retain(): ManagedBuffer = { + callsToRetain += 1 + underlyingBuffer.retain() + } + override def release(): ManagedBuffer = { + callsToRelease += 1 + underlyingBuffer.release() + } +} + +class HashShuffleReaderSuite extends SparkFunSuite with LocalSparkContext { + + /** + * This test makes sure that, when data is read from a HashShuffleReader, the underlying + * ManagedBuffers that contain the data are eventually released. + */ + test("read() releases resources on completion") { + val testConf = new SparkConf(false) + // Create a SparkContext as a convenient way of setting SparkEnv (needed because some of the + // shuffle code calls SparkEnv.get()). + sc = new SparkContext("local", "test", testConf) + + val reduceId = 15 + val shuffleId = 22 + val numMaps = 6 + val keyValuePairsPerMap = 10 + val serializer = new JavaSerializer(testConf) + + // Make a mock BlockManager that will return RecordingManagedByteBuffers of data, so that we + // can ensure retain() and release() are properly called. + val blockManager = mock(classOf[BlockManager]) + + // Create a return function to use for the mocked wrapForCompression method that just returns + // the original input stream. + val dummyCompressionFunction = new Answer[InputStream] { + override def answer(invocation: InvocationOnMock): InputStream = + invocation.getArguments()(1).asInstanceOf[InputStream] + } + + // Create a buffer with some randomly generated key-value pairs to use as the shuffle data + // from each mappers (all mappers return the same shuffle data). + val byteOutputStream = new ByteArrayOutputStream() + val serializationStream = serializer.newInstance().serializeStream(byteOutputStream) + (0 until keyValuePairsPerMap).foreach { i => + serializationStream.writeKey(i) + serializationStream.writeValue(2*i) + } + + // Setup the mocked BlockManager to return RecordingManagedBuffers. + val localBlockManagerId = BlockManagerId("test-client", "test-client", 1) + when(blockManager.blockManagerId).thenReturn(localBlockManagerId) + val buffers = (0 until numMaps).map { mapId => + // Create a ManagedBuffer with the shuffle data. + val nioBuffer = new NioManagedBuffer(ByteBuffer.wrap(byteOutputStream.toByteArray)) + val managedBuffer = new RecordingManagedBuffer(nioBuffer) + + // Setup the blockManager mock so the buffer gets returned when the shuffle code tries to + // fetch shuffle data. + val shuffleBlockId = ShuffleBlockId(shuffleId, mapId, reduceId) + when(blockManager.getBlockData(shuffleBlockId)).thenReturn(managedBuffer) + when(blockManager.wrapForCompression(meq(shuffleBlockId), isA(classOf[InputStream]))) + .thenAnswer(dummyCompressionFunction) + + managedBuffer + } + + // Make a mocked MapOutputTracker for the shuffle reader to use to determine what + // shuffle data to read. + val mapOutputTracker = mock(classOf[MapOutputTracker]) + // Test a scenario where all data is local, just to avoid creating a bunch of additional mocks + // for the code to read data over the network. 
+ val statuses: Array[(BlockManagerId, Long)] = + Array.fill(numMaps)((localBlockManagerId, byteOutputStream.size().toLong)) + when(mapOutputTracker.getServerStatuses(shuffleId, reduceId)).thenReturn(statuses) + + // Create a mocked shuffle handle to pass into HashShuffleReader. + val shuffleHandle = { + val dependency = mock(classOf[ShuffleDependency[Int, Int, Int]]) + when(dependency.serializer).thenReturn(Some(serializer)) + when(dependency.aggregator).thenReturn(None) + when(dependency.keyOrdering).thenReturn(None) + new BaseShuffleHandle(shuffleId, numMaps, dependency) + } + + val shuffleReader = new HashShuffleReader( + shuffleHandle, + reduceId, + reduceId + 1, + new TaskContextImpl(0, 0, 0, 0, null), + blockManager, + mapOutputTracker) + + assert(shuffleReader.read().length === keyValuePairsPerMap * numMaps) + + // Calling .length above will have exhausted the iterator; make sure that exhausting the + // iterator caused retain and release to be called on each buffer. + buffers.foreach { buffer => + assert(buffer.callsToRetain === 1) + assert(buffer.callsToRelease === 1) + } + } +} diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 2a7fe67ad8585..9ced4148d7206 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -17,23 +17,25 @@ package org.apache.spark.storage +import java.io.InputStream import java.util.concurrent.Semaphore -import scala.concurrent.future import scala.concurrent.ExecutionContext.Implicits.global +import scala.concurrent.future import org.mockito.Matchers.{any, eq => meq} import org.mockito.Mockito._ import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer +import org.scalatest.PrivateMethodTester -import org.apache.spark.{SparkConf, SparkFunSuite, TaskContextImpl} +import org.apache.spark.{SparkFunSuite, TaskContextImpl} import org.apache.spark.network._ import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.network.shuffle.BlockFetchingListener -import org.apache.spark.serializer.TestSerializer -class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite { + +class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodTester { // Some of the tests are quite tricky because we are testing the cleanup behavior // in the presence of faults. 
@@ -57,7 +59,12 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite { transfer } - private val conf = new SparkConf + // Create a mock managed buffer for testing + def createMockManagedBuffer(): ManagedBuffer = { + val mockManagedBuffer = mock(classOf[ManagedBuffer]) + when(mockManagedBuffer.createInputStream()).thenReturn(mock(classOf[InputStream])) + mockManagedBuffer + } test("successful 3 local reads + 2 remote reads") { val blockManager = mock(classOf[BlockManager]) @@ -66,9 +73,9 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite { // Make sure blockManager.getBlockData would return the blocks val localBlocks = Map[BlockId, ManagedBuffer]( - ShuffleBlockId(0, 0, 0) -> mock(classOf[ManagedBuffer]), - ShuffleBlockId(0, 1, 0) -> mock(classOf[ManagedBuffer]), - ShuffleBlockId(0, 2, 0) -> mock(classOf[ManagedBuffer])) + ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer(), + ShuffleBlockId(0, 1, 0) -> createMockManagedBuffer(), + ShuffleBlockId(0, 2, 0) -> createMockManagedBuffer()) localBlocks.foreach { case (blockId, buf) => doReturn(buf).when(blockManager).getBlockData(meq(blockId)) } @@ -76,9 +83,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite { // Make sure remote blocks would return val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2) val remoteBlocks = Map[BlockId, ManagedBuffer]( - ShuffleBlockId(0, 3, 0) -> mock(classOf[ManagedBuffer]), - ShuffleBlockId(0, 4, 0) -> mock(classOf[ManagedBuffer]) - ) + ShuffleBlockId(0, 3, 0) -> createMockManagedBuffer(), + ShuffleBlockId(0, 4, 0) -> createMockManagedBuffer()) val transfer = createMockTransfer(remoteBlocks) @@ -92,7 +98,6 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite { transfer, blockManager, blocksByAddress, - new TestSerializer, 48 * 1024 * 1024) // 3 local blocks fetched in initialization @@ -100,15 +105,24 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite { for (i <- 0 until 5) { assert(iterator.hasNext, s"iterator should have 5 elements but actually has $i elements") - val (blockId, subIterator) = iterator.next() - assert(subIterator.isSuccess, + val (blockId, inputStream) = iterator.next() + assert(inputStream.isSuccess, s"iterator should have 5 elements defined but actually has $i elements") - // Make sure we release the buffer once the iterator is exhausted. + // Make sure we release buffers when a wrapped input stream is closed. 
val mockBuf = localBlocks.getOrElse(blockId, remoteBlocks(blockId)) + // Note: ShuffleBlockFetcherIterator wraps input streams in a BufferReleasingInputStream + val wrappedInputStream = inputStream.get.asInstanceOf[BufferReleasingInputStream] verify(mockBuf, times(0)).release() - subIterator.get.foreach(_ => Unit) // exhaust the iterator + val delegateAccess = PrivateMethod[InputStream]('delegate) + + verify(wrappedInputStream.invokePrivate(delegateAccess()), times(0)).close() + wrappedInputStream.close() + verify(mockBuf, times(1)).release() + verify(wrappedInputStream.invokePrivate(delegateAccess()), times(1)).close() + wrappedInputStream.close() // close should be idempotent verify(mockBuf, times(1)).release() + verify(wrappedInputStream.invokePrivate(delegateAccess()), times(1)).close() } // 3 local blocks, and 2 remote blocks @@ -125,10 +139,9 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite { // Make sure remote blocks would return val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2) val blocks = Map[BlockId, ManagedBuffer]( - ShuffleBlockId(0, 0, 0) -> mock(classOf[ManagedBuffer]), - ShuffleBlockId(0, 1, 0) -> mock(classOf[ManagedBuffer]), - ShuffleBlockId(0, 2, 0) -> mock(classOf[ManagedBuffer]) - ) + ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer(), + ShuffleBlockId(0, 1, 0) -> createMockManagedBuffer(), + ShuffleBlockId(0, 2, 0) -> createMockManagedBuffer()) // Semaphore to coordinate event sequence in two different threads. val sem = new Semaphore(0) @@ -159,11 +172,10 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite { transfer, blockManager, blocksByAddress, - new TestSerializer, 48 * 1024 * 1024) - // Exhaust the first block, and then it should be released. - iterator.next()._2.get.foreach(_ => Unit) + verify(blocks(ShuffleBlockId(0, 0, 0)), times(0)).release() + iterator.next()._2.get.close() // close() first block's input stream verify(blocks(ShuffleBlockId(0, 0, 0)), times(1)).release() // Get the 2nd block but do not exhaust the iterator @@ -222,7 +234,6 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite { transfer, blockManager, blocksByAddress, - new TestSerializer, 48 * 1024 * 1024) // Continue only after the mock calls onBlockFetchFailure From c337844ed7f9b2cb7b217dc935183ef5e1096ca1 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Thu, 25 Jun 2015 00:06:23 -0700 Subject: [PATCH 008/122] [SPARK-8604] [SQL] HadoopFsRelation subclasses should set their output format class `HadoopFsRelation` subclasses, especially `ParquetRelation2` should set its own output format class, so that the default output committer can be setup correctly when doing appending (where we ignore user defined output committers). 
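
A minimal sketch of the pattern this change applies, not code taken from the patch itself: a `HadoopFsRelation`-style data source registers its concrete output format while preparing the write job, so that the default output committer tied to that format is set up correctly when appending (where user-defined committers are ignored). The class name and the simplified `Unit` return type are assumptions for illustration; the `setOutputFormatClass` call mirrors the ones added in the diff below.

```scala
import org.apache.hadoop.mapreduce.Job
import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat

// Hypothetical relation; only the write-preparation hook is sketched.
// The real prepareJobForWrite returns an OutputWriterFactory, elided here.
class ExampleTextRelation {
  def prepareJobForWrite(job: Job): Unit = {
    // Declaring the output format up front lets the committer be resolved
    // from this format class when Spark appends to an existing directory.
    job.setOutputFormatClass(classOf[TextOutputFormat[_, _]])
  }
}
```
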
Author: Cheng Lian Closes #6998 from liancheng/spark-8604 and squashes the following commits: 9be51d1 [Cheng Lian] Adds more comments 6db1368 [Cheng Lian] HadoopFsRelation subclasses should set their output format class --- .../apache/spark/sql/parquet/newParquet.scala | 6 ++++++ .../spark/sql/hive/orc/OrcRelation.scala | 12 ++++++++++- .../sql/sources/SimpleTextRelation.scala | 2 ++ .../sql/sources/hadoopFsRelationSuites.scala | 21 +++++++++++++++++++ 4 files changed, 40 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala index 1d353bd8e1114..bc39fae2bcfde 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala @@ -194,6 +194,12 @@ private[sql] class ParquetRelation2( committerClass, classOf[ParquetOutputCommitter]) + // We're not really using `ParquetOutputFormat[Row]` for writing data here, because we override + // it in `ParquetOutputWriter` to support appending and dynamic partitioning. The reason why + // we set it here is to setup the output committer class to `ParquetOutputCommitter`, which is + // bundled with `ParquetOutputFormat[Row]`. + job.setOutputFormatClass(classOf[ParquetOutputFormat[Row]]) + // TODO There's no need to use two kinds of WriteSupport // We should unify them. `SpecificMutableRow` can process both atomic (primitive) types and // complex types. diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala index 705f48f1cd9f0..0fd7b3a91d6dd 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala @@ -27,7 +27,7 @@ import org.apache.hadoop.hive.ql.io.orc.{OrcInputFormat, OrcOutputFormat, OrcSer import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils import org.apache.hadoop.io.{NullWritable, Writable} -import org.apache.hadoop.mapred.{InputFormat => MapRedInputFormat, JobConf, RecordWriter, Reporter} +import org.apache.hadoop.mapred.{InputFormat => MapRedInputFormat, JobConf, OutputFormat => MapRedOutputFormat, RecordWriter, Reporter} import org.apache.hadoop.mapreduce.lib.input.FileInputFormat import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} @@ -194,6 +194,16 @@ private[sql] class OrcRelation( } override def prepareJobForWrite(job: Job): OutputWriterFactory = { + job.getConfiguration match { + case conf: JobConf => + conf.setOutputFormat(classOf[OrcOutputFormat]) + case conf => + conf.setClass( + "mapred.output.format.class", + classOf[OrcOutputFormat], + classOf[MapRedOutputFormat[_, _]]) + } + new OutputWriterFactory { override def newInstance( path: String, diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala index 5d7cd16c129cd..e8141923a9b5c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala @@ -119,6 +119,8 @@ class SimpleTextRelation( } override def prepareJobForWrite(job: Job): OutputWriterFactory = new OutputWriterFactory { + job.setOutputFormatClass(classOf[TextOutputFormat[_, _]]) + override def 
newInstance( path: String, dataSchema: StructType, diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala index a16ab3a00ddb8..afecf9675e11f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala @@ -719,4 +719,25 @@ class ParquetHadoopFsRelationSuite extends HadoopFsRelationTest { } } } + + test("SPARK-8604: Parquet data source should write summary file while doing appending") { + withTempPath { dir => + val path = dir.getCanonicalPath + val df = sqlContext.range(0, 5) + df.write.mode(SaveMode.Overwrite).parquet(path) + + val summaryPath = new Path(path, "_metadata") + val commonSummaryPath = new Path(path, "_common_metadata") + + val fs = summaryPath.getFileSystem(configuration) + fs.delete(summaryPath, true) + fs.delete(commonSummaryPath, true) + + df.write.mode(SaveMode.Append).parquet(path) + checkAnswer(sqlContext.read.parquet(path), df.unionAll(df)) + + assert(fs.exists(summaryPath)) + assert(fs.exists(commonSummaryPath)) + } + } } From 085a7216bf5e6c2b4f297feca4af71a751e37975 Mon Sep 17 00:00:00 2001 From: Joshi Date: Thu, 25 Jun 2015 20:21:34 +0900 Subject: [PATCH 009/122] [SPARK-5768] [WEB UI] Fix for incorrect memory in Spark UI Fix for incorrect memory in Spark UI as per SPARK-5768 Author: Joshi Author: Rekha Joshi Closes #6972 from rekhajoshm/SPARK-5768 and squashes the following commits: b678a91 [Joshi] Fix for incorrect memory in Spark UI 2fe53d9 [Joshi] Fix for incorrect memory in Spark UI eb823b8 [Joshi] SPARK-5768: Fix for incorrect memory in Spark UI 0be142d [Rekha Joshi] Merge pull request #3 from apache/master 106fd8e [Rekha Joshi] Merge pull request #2 from apache/master e3677c9 [Rekha Joshi] Merge pull request #1 from apache/master --- core/src/main/scala/org/apache/spark/ui/ToolTips.scala | 4 ++++ .../main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/ui/ToolTips.scala b/core/src/main/scala/org/apache/spark/ui/ToolTips.scala index 063e2a1f8b18e..e2d25e36365fa 100644 --- a/core/src/main/scala/org/apache/spark/ui/ToolTips.scala +++ b/core/src/main/scala/org/apache/spark/ui/ToolTips.scala @@ -35,6 +35,10 @@ private[spark] object ToolTips { val OUTPUT = "Bytes and records written to Hadoop." + val STORAGE_MEMORY = + "Memory used / total available memory for storage of data " + + "like RDD partitions cached in memory. " + val SHUFFLE_WRITE = "Bytes and records written to disk in order to be read by a shuffle in a future stage." 
diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala index b247e4cdc3bd4..01cddda4c62cd 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala @@ -67,7 +67,7 @@ private[ui] class ExecutorsPage( Executor ID Address RDD Blocks - Memory Used + Storage Memory Disk Used Active Tasks Failed Tasks From e988adb58f02d06065837f3d79eee220f6558def Mon Sep 17 00:00:00 2001 From: Tom Graves Date: Thu, 25 Jun 2015 08:27:08 -0500 Subject: [PATCH 010/122] =?UTF-8?q?[SPARK-8574]=20org/apache/spark/unsafe?= =?UTF-8?q?=20doesn't=20honor=20the=20java=20source/ta=E2=80=A6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …rget versions. I basically copied the compatibility rules from the top level pom.xml into here. Someone more familiar with all the options in the top level pom may want to make sure nothing else should be copied on down. With this is allows me to build with jdk8 and run with lower versions. Source shows compiled for jdk6 as its supposed to. Author: Tom Graves Author: Thomas Graves Closes #6989 from tgravescs/SPARK-8574 and squashes the following commits: e1ea2d4 [Thomas Graves] Change to use combine.children="append" 150d645 [Tom Graves] [SPARK-8574] org/apache/spark/unsafe doesn't honor the java source/target versions --- unsafe/pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/unsafe/pom.xml b/unsafe/pom.xml index 62c6354f1e203..dd2ae6457f0b9 100644 --- a/unsafe/pom.xml +++ b/unsafe/pom.xml @@ -80,7 +80,7 @@ net.alchim31.maven scala-maven-plugin - + -XDignore.symbol.file From f9b397f54d1c491680d70aba210bb8211fd249c1 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Thu, 25 Jun 2015 06:52:03 -0700 Subject: [PATCH 011/122] [SPARK-8567] [SQL] Add logs to record the progress of HiveSparkSubmitSuite. Author: Yin Huai Closes #7009 from yhuai/SPARK-8567 and squashes the following commits: 62fb1f9 [Yin Huai] Add sc.stop(). b22cf7d [Yin Huai] Add logs. --- .../org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala index b875e52b986ab..a38ed23b5cf9a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala @@ -115,6 +115,7 @@ object SparkSubmitClassLoaderTest extends Logging { val sc = new SparkContext(conf) val hiveContext = new TestHiveContext(sc) val df = hiveContext.createDataFrame((1 to 100).map(i => (i, i))).toDF("i", "j") + logInfo("Testing load classes at the driver side.") // First, we load classes at driver side. try { Class.forName(args(0), true, Thread.currentThread().getContextClassLoader) @@ -124,6 +125,7 @@ object SparkSubmitClassLoaderTest extends Logging { throw new Exception("Could not load user class from jar:\n", t) } // Second, we load classes at the executor side. + logInfo("Testing load classes at the executor side.") val result = df.mapPartitions { x => var exception: String = null try { @@ -141,6 +143,7 @@ object SparkSubmitClassLoaderTest extends Logging { } // Load a Hive UDF from the jar. 
+ logInfo("Registering temporary Hive UDF provided in a jar.") hiveContext.sql( """ |CREATE TEMPORARY FUNCTION example_max @@ -150,18 +153,23 @@ object SparkSubmitClassLoaderTest extends Logging { hiveContext.createDataFrame((1 to 10).map(i => (i, s"str$i"))).toDF("key", "val") source.registerTempTable("sourceTable") // Load a Hive SerDe from the jar. + logInfo("Creating a Hive table with a SerDe provided in a jar.") hiveContext.sql( """ |CREATE TABLE t1(key int, val string) |ROW FORMAT SERDE 'org.apache.hive.hcatalog.data.JsonSerDe' """.stripMargin) // Actually use the loaded UDF and SerDe. + logInfo("Writing data into the table.") hiveContext.sql( "INSERT INTO TABLE t1 SELECT example_max(key) as key, val FROM sourceTable GROUP BY val") + logInfo("Running a simple query on the table.") val count = hiveContext.table("t1").orderBy("key", "val").count() if (count != 10) { throw new Exception(s"table t1 should have 10 rows instead of $count rows") } + logInfo("Test finishes.") + sc.stop() } } @@ -199,5 +207,6 @@ object SparkSQLConfTest extends Logging { val hiveContext = new TestHiveContext(sc) // Run a simple command to make sure all lazy vals in hiveContext get instantiated. hiveContext.tables().collect() + sc.stop() } } From 2519dcc33bde3a6d341790d73b5d292ea7af961a Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Thu, 25 Jun 2015 08:13:17 -0700 Subject: [PATCH 012/122] [MINOR] [MLLIB] rename some functions of PythonMLLibAPI Keep the same naming conventions for PythonMLLibAPI. Only the following three functions is different from others ```scala trainNaiveBayes trainGaussianMixture trainWord2Vec ``` So change them to ```scala trainNaiveBayesModel trainGaussianMixtureModel trainWord2VecModel ``` It does not affect any users and public APIs, only to make better understand for developer and code hacker. Author: Yanbo Liang Closes #7011 from yanboliang/py-mllib-api-rename and squashes the following commits: 771ffec [Yanbo Liang] rename some functions of PythonMLLibAPI --- .../org/apache/spark/mllib/api/python/PythonMLLibAPI.scala | 6 +++--- python/pyspark/mllib/classification.py | 2 +- python/pyspark/mllib/clustering.py | 6 +++--- python/pyspark/mllib/feature.py | 2 +- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index c4bea7c2cad4f..b16903a8d515c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -278,7 +278,7 @@ private[python] class PythonMLLibAPI extends Serializable { /** * Java stub for NaiveBayes.train() */ - def trainNaiveBayes( + def trainNaiveBayesModel( data: JavaRDD[LabeledPoint], lambda: Double): JList[Object] = { val model = NaiveBayes.train(data.rdd, lambda) @@ -346,7 +346,7 @@ private[python] class PythonMLLibAPI extends Serializable { * Java stub for Python mllib GaussianMixture.run() * Returns a list containing weights, mean and covariance of each mixture component. 
*/ - def trainGaussianMixture( + def trainGaussianMixtureModel( data: JavaRDD[Vector], k: Int, convergenceTol: Double, @@ -553,7 +553,7 @@ private[python] class PythonMLLibAPI extends Serializable { * @param seed initial seed for random generator * @return A handle to java Word2VecModelWrapper instance at python side */ - def trainWord2Vec( + def trainWord2VecModel( dataJRDD: JavaRDD[java.util.ArrayList[String]], vectorSize: Int, learningRate: Double, diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py index 2698f10d06883..735d45ba03d27 100644 --- a/python/pyspark/mllib/classification.py +++ b/python/pyspark/mllib/classification.py @@ -581,7 +581,7 @@ def train(cls, data, lambda_=1.0): first = data.first() if not isinstance(first, LabeledPoint): raise ValueError("`data` should be an RDD of LabeledPoint") - labels, pi, theta = callMLlibFunc("trainNaiveBayes", data, lambda_) + labels, pi, theta = callMLlibFunc("trainNaiveBayesModel", data, lambda_) return NaiveBayesModel(labels.toArray(), pi.toArray(), numpy.array(theta)) diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py index e6ef72942ce77..8bc0654c76ca3 100644 --- a/python/pyspark/mllib/clustering.py +++ b/python/pyspark/mllib/clustering.py @@ -265,9 +265,9 @@ def train(cls, rdd, k, convergenceTol=1e-3, maxIterations=100, seed=None, initia initialModelWeights = initialModel.weights initialModelMu = [initialModel.gaussians[i].mu for i in range(initialModel.k)] initialModelSigma = [initialModel.gaussians[i].sigma for i in range(initialModel.k)] - weight, mu, sigma = callMLlibFunc("trainGaussianMixture", rdd.map(_convert_to_vector), k, - convergenceTol, maxIterations, seed, initialModelWeights, - initialModelMu, initialModelSigma) + weight, mu, sigma = callMLlibFunc("trainGaussianMixtureModel", rdd.map(_convert_to_vector), + k, convergenceTol, maxIterations, seed, + initialModelWeights, initialModelMu, initialModelSigma) mvg_obj = [MultivariateGaussian(mu[i], sigma[i]) for i in range(k)] return GaussianMixtureModel(weight, mvg_obj) diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py index 334f5b86cd392..f00bb93b7bf40 100644 --- a/python/pyspark/mllib/feature.py +++ b/python/pyspark/mllib/feature.py @@ -549,7 +549,7 @@ def fit(self, data): """ if not isinstance(data, RDD): raise TypeError("data should be an RDD of list of string") - jmodel = callMLlibFunc("trainWord2Vec", data, int(self.vectorSize), + jmodel = callMLlibFunc("trainWord2VecModel", data, int(self.vectorSize), float(self.learningRate), int(self.numPartitions), int(self.numIterations), int(self.seed), int(self.minCount)) From c392a9efabcb1ec2a2c53f001ecdae33c245ba35 Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Thu, 25 Jun 2015 10:56:00 -0700 Subject: [PATCH 013/122] [SPARK-8637] [SPARKR] [HOTFIX] Fix packages argument, sparkSubmitBinName cc cafreeman Author: Shivaram Venkataraman Closes #7022 from shivaram/sparkr-init-hotfix and squashes the following commits: 9178d15 [Shivaram Venkataraman] Fix packages argument, sparkSubmitBinName --- R/pkg/R/client.R | 2 +- R/pkg/R/sparkR.R | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/R/pkg/R/client.R b/R/pkg/R/client.R index cf2e5ddeb7a9d..78c7a3037ffac 100644 --- a/R/pkg/R/client.R +++ b/R/pkg/R/client.R @@ -57,7 +57,7 @@ generateSparkSubmitArgs <- function(args, sparkHome, jars, sparkSubmitOpts, pack } launchBackend <- function(args, sparkHome, jars, sparkSubmitOpts, packages) { - sparkSubmitBin <- 
determineSparkSubmitBin() + sparkSubmitBinName <- determineSparkSubmitBin() if (sparkHome != "") { sparkSubmitBin <- file.path(sparkHome, "bin", sparkSubmitBinName) } else { diff --git a/R/pkg/R/sparkR.R b/R/pkg/R/sparkR.R index 8f81d5640c1d0..633b869f91784 100644 --- a/R/pkg/R/sparkR.R +++ b/R/pkg/R/sparkR.R @@ -132,7 +132,7 @@ sparkR.init <- function( sparkHome = sparkHome, jars = jars, sparkSubmitOpts = Sys.getenv("SPARKR_SUBMIT_ARGS", "sparkr-shell"), - sparkPackages = sparkPackages) + packages = sparkPackages) # wait atmost 100 seconds for JVM to launch wait <- 0.1 for (i in 1:25) { From 47c874babe7779c7a2f32e0b891503ef6bebcab0 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 25 Jun 2015 22:07:37 -0700 Subject: [PATCH 014/122] [SPARK-8237] [SQL] Add misc function sha2 JIRA: https://issues.apache.org/jira/browse/SPARK-8237 Author: Liang-Chi Hsieh Closes #6934 from viirya/expr_sha2 and squashes the following commits: 35e0bb3 [Liang-Chi Hsieh] For comments. 68b5284 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into expr_sha2 8573aff [Liang-Chi Hsieh] Remove unnecessary Product. ee61e06 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into expr_sha2 59e41aa [Liang-Chi Hsieh] Add misc function: sha2. --- python/pyspark/sql/functions.py | 19 ++++ .../catalyst/analysis/FunctionRegistry.scala | 1 + .../spark/sql/catalyst/expressions/misc.scala | 98 ++++++++++++++++++- .../expressions/MiscFunctionsSuite.scala | 14 ++- .../org/apache/spark/sql/functions.scala | 20 ++++ .../spark/sql/DataFrameFunctionsSuite.scala | 17 ++++ 6 files changed, 165 insertions(+), 4 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index cfa87aeea193a..7d3d0361610b7 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -42,6 +42,7 @@ 'monotonicallyIncreasingId', 'rand', 'randn', + 'sha2', 'sparkPartitionId', 'struct', 'udf', @@ -363,6 +364,24 @@ def randn(seed=None): return Column(jc) +@ignore_unicode_prefix +@since(1.5) +def sha2(col, numBits): + """Returns the hex string result of SHA-2 family of hash functions (SHA-224, SHA-256, SHA-384, + and SHA-512). The numBits indicates the desired bit length of the result, which must have a + value of 224, 256, 384, 512, or 0 (which is equivalent to 256). + + >>> digests = df.select(sha2(df.name, 256).alias('s')).collect() + >>> digests[0] + Row(s=u'3bc51062973c458d5a6f2d8d64a023246354ad7e064b1e4e009ec8a0699a3043') + >>> digests[1] + Row(s=u'cd9fb1e148ccd8442e5aa74904cc73bf6fb54d1d54d333bd596aa9bb4bb4e961') + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.sha2(_to_java_column(col), numBits) + return Column(jc) + + @since(1.4) def sparkPartitionId(): """A column for partition ID of the Spark task. 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 5fb3369f85d12..457948a800a17 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -135,6 +135,7 @@ object FunctionRegistry { // misc functions expression[Md5]("md5"), + expression[Sha2]("sha2"), // aggregate functions expression[Average]("avg"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index 4bee8cb728e5c..e80706fc65aff 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -17,9 +17,12 @@ package org.apache.spark.sql.catalyst.expressions +import java.security.MessageDigest +import java.security.NoSuchAlgorithmException + import org.apache.commons.codec.digest.DigestUtils import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.types.{BinaryType, StringType, DataType} +import org.apache.spark.sql.types.{BinaryType, IntegerType, StringType, DataType} import org.apache.spark.unsafe.types.UTF8String /** @@ -44,7 +47,96 @@ case class Md5(child: Expression) override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { defineCodeGen(ctx, ev, c => - "org.apache.spark.unsafe.types.UTF8String.fromString" + - s"(org.apache.commons.codec.digest.DigestUtils.md5Hex($c))") + s"${ctx.stringType}.fromString(org.apache.commons.codec.digest.DigestUtils.md5Hex($c))") + } +} + +/** + * A function that calculates the SHA-2 family of functions (SHA-224, SHA-256, SHA-384, and SHA-512) + * and returns it as a hex string. The first argument is the string or binary to be hashed. The + * second argument indicates the desired bit length of the result, which must have a value of 224, + * 256, 384, 512, or 0 (which is equivalent to 256). SHA-224 is supported starting from Java 8. If + * asking for an unsupported SHA function, the return value is NULL. If either argument is NULL or + * the hash length is not one of the permitted values, the return value is NULL. 
+ */ +case class Sha2(left: Expression, right: Expression) + extends BinaryExpression with Serializable with ExpectsInputTypes { + + override def dataType: DataType = StringType + + override def toString: String = s"SHA2($left, $right)" + + override def expectedChildTypes: Seq[DataType] = Seq(BinaryType, IntegerType) + + override def eval(input: InternalRow): Any = { + val evalE1 = left.eval(input) + if (evalE1 == null) { + null + } else { + val evalE2 = right.eval(input) + if (evalE2 == null) { + null + } else { + val bitLength = evalE2.asInstanceOf[Int] + val input = evalE1.asInstanceOf[Array[Byte]] + bitLength match { + case 224 => + // DigestUtils doesn't support SHA-224 now + try { + val md = MessageDigest.getInstance("SHA-224") + md.update(input) + UTF8String.fromBytes(md.digest()) + } catch { + // SHA-224 is not supported on the system, return null + case noa: NoSuchAlgorithmException => null + } + case 256 | 0 => + UTF8String.fromString(DigestUtils.sha256Hex(input)) + case 384 => + UTF8String.fromString(DigestUtils.sha384Hex(input)) + case 512 => + UTF8String.fromString(DigestUtils.sha512Hex(input)) + case _ => null + } + } + } + } + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val eval1 = left.gen(ctx) + val eval2 = right.gen(ctx) + val digestUtils = "org.apache.commons.codec.digest.DigestUtils" + + s""" + ${eval1.code} + boolean ${ev.isNull} = ${eval1.isNull}; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + ${eval2.code} + if (!${eval2.isNull}) { + if (${eval2.primitive} == 224) { + try { + java.security.MessageDigest md = java.security.MessageDigest.getInstance("SHA-224"); + md.update(${eval1.primitive}); + ${ev.primitive} = ${ctx.stringType}.fromBytes(md.digest()); + } catch (java.security.NoSuchAlgorithmException e) { + ${ev.isNull} = true; + } + } else if (${eval2.primitive} == 256 || ${eval2.primitive} == 0) { + ${ev.primitive} = + ${ctx.stringType}.fromString(${digestUtils}.sha256Hex(${eval1.primitive})); + } else if (${eval2.primitive} == 384) { + ${ev.primitive} = + ${ctx.stringType}.fromString(${digestUtils}.sha384Hex(${eval1.primitive})); + } else if (${eval2.primitive} == 512) { + ${ev.primitive} = + ${ctx.stringType}.fromString(${digestUtils}.sha512Hex(${eval1.primitive})); + } else { + ${ev.isNull} = true; + } + } else { + ${ev.isNull} = true; + } + } + """ } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala index 48b84130b4556..38482c54c61db 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala @@ -17,8 +17,10 @@ package org.apache.spark.sql.catalyst.expressions +import org.apache.commons.codec.digest.DigestUtils + import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.types.{StringType, BinaryType} +import org.apache.spark.sql.types.{IntegerType, StringType, BinaryType} class MiscFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -29,4 +31,14 @@ class MiscFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Md5(Literal.create(null, BinaryType)), null) } + test("sha2") { + checkEvaluation(Sha2(Literal("ABC".getBytes), Literal(256)), DigestUtils.sha256Hex("ABC")) + 
checkEvaluation(Sha2(Literal.create(Array[Byte](1, 2, 3, 4, 5, 6), BinaryType), Literal(384)), + DigestUtils.sha384Hex(Array[Byte](1, 2, 3, 4, 5, 6))) + // unsupported bit length + checkEvaluation(Sha2(Literal.create(null, BinaryType), Literal(1024)), null) + checkEvaluation(Sha2(Literal.create(null, BinaryType), Literal(512)), null) + checkEvaluation(Sha2(Literal("ABC".getBytes), Literal.create(null, IntegerType)), null) + checkEvaluation(Sha2(Literal.create(null, BinaryType), Literal.create(null, IntegerType)), null) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 38d9085a505fb..355ce0e3423cf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1414,6 +1414,26 @@ object functions { */ def md5(columnName: String): Column = md5(Column(columnName)) + /** + * Calculates the SHA-2 family of hash functions and returns the value as a hex string. + * + * @group misc_funcs + * @since 1.5.0 + */ + def sha2(e: Column, numBits: Int): Column = { + require(Seq(0, 224, 256, 384, 512).contains(numBits), + s"numBits $numBits is not in the permitted values (0, 224, 256, 384, 512)") + Sha2(e.expr, lit(numBits).expr) + } + + /** + * Calculates the SHA-2 family of hash functions and returns the value as a hex string. + * + * @group misc_funcs + * @since 1.5.0 + */ + def sha2(columnName: String, numBits: Int): Column = sha2(Column(columnName), numBits) + ////////////////////////////////////////////////////////////////////////////////////////////// // String functions ////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 8b53b384a22fd..8baed57a7f129 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -144,6 +144,23 @@ class DataFrameFunctionsSuite extends QueryTest { Row("902fbdd2b1df0c4f70b4a5d23525e932", "6ac1e56bc78f031059be7be854522c4c")) } + test("misc sha2 function") { + val df = Seq(("ABC", Array[Byte](1, 2, 3, 4, 5, 6))).toDF("a", "b") + checkAnswer( + df.select(sha2($"a", 256), sha2("b", 256)), + Row("b5d4045c3f466fa91fe2cc6abe79232a1a57cdf104f7a26e716e0a1e2789df78", + "7192385c3c0605de55bb9476ce1d90748190ecb32a8eed7f5207b30cf6a1fe89")) + + checkAnswer( + df.selectExpr("sha2(a, 256)", "sha2(b, 256)"), + Row("b5d4045c3f466fa91fe2cc6abe79232a1a57cdf104f7a26e716e0a1e2789df78", + "7192385c3c0605de55bb9476ce1d90748190ecb32a8eed7f5207b30cf6a1fe89")) + + intercept[IllegalArgumentException] { + df.select(sha2($"a", 1024)) + } + } + test("string length function") { checkAnswer( nullStrings.select(strlen($"s"), strlen("s")), From 40360112c417b5432564f4bcb8a9100f4066b55e Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 25 Jun 2015 22:16:53 -0700 Subject: [PATCH 015/122] [SPARK-8620] [SQL] cleanup CodeGenContext fix docs, remove nativeTypes, use the Java type to get the boxed type, default value, etc., to avoid handling `DateType` and `TimestampType` as int and long again and again.
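As a reading aid before the diff below, here is a minimal, standalone sketch of the pattern this cleanup introduces (illustrative only — simplified stand-in types, not Spark's actual CodeGenContext): every helper is keyed off the generated Java type string, so DateType and TimestampType automatically reuse the int and long cases instead of being special-cased in each helper.

```scala
// Minimal sketch (assumed names, not Spark code): classify a DataType once in
// javaType, then derive the boxed type and default value from the Java type string.
object JavaTypeSketch {
  sealed trait DataType
  case object IntegerType extends DataType
  case object LongType extends DataType
  case object DateType extends DataType        // stored internally as an int
  case object TimestampType extends DataType   // stored internally as a long

  def javaType(dt: DataType): String = dt match {
    case IntegerType | DateType   => "int"
    case LongType | TimestampType => "long"
    case _                        => "Object"
  }

  // Keyed off the Java type string, not the DataType.
  def boxedType(jt: String): String = jt match {
    case "int"  => "Integer"
    case "long" => "Long"
    case other  => other
  }

  def defaultValue(jt: String): String = jt match {
    case "int"  => "-1"
    case "long" => "-1L"
    case _      => "null"
  }

  def main(args: Array[String]): Unit = {
    println(boxedType(javaType(DateType)))         // Integer
    println(defaultValue(javaType(TimestampType))) // -1L
  }
}
```

With this shape, only `javaType` needs to know that a date is an int and a timestamp is a long; `boxedType` and `defaultValue` never look at the `DataType` at all.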
Author: Wenchen Fan Closes #7010 from cloud-fan/cg and squashes the following commits: aa01cf9 [Wenchen Fan] cleanup CodeGenContext --- .../spark/sql/catalyst/expressions/Cast.scala | 5 +- .../expressions/codegen/CodeGenerator.scala | 130 +++++++++--------- .../codegen/GenerateProjection.scala | 34 ++--- .../expressions/stringOperations.scala | 1 - 4 files changed, 82 insertions(+), 88 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 8bd7fc18a8dd4..8d66968a2fc35 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -467,11 +467,8 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w defineCodeGen(ctx, ev, c => s"!$c.isZero()") case (dt: NumericType, BooleanType) => defineCodeGen(ctx, ev, c => s"$c != 0") - - case (_: DecimalType, IntegerType) => - defineCodeGen(ctx, ev, c => s"($c).toInt()") case (_: DecimalType, dt: NumericType) => - defineCodeGen(ctx, ev, c => s"($c).to${ctx.boxedType(dt)}()") + defineCodeGen(ctx, ev, c => s"($c).to${ctx.primitiveTypeName(dt)}()") case (_: NumericType, dt: NumericType) => defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dt)})($c)") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 47c5455435ec6..e20e3a9dca502 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -59,6 +59,14 @@ class CodeGenContext { val stringType: String = classOf[UTF8String].getName val decimalType: String = classOf[Decimal].getName + final val JAVA_BOOLEAN = "boolean" + final val JAVA_BYTE = "byte" + final val JAVA_SHORT = "short" + final val JAVA_INT = "int" + final val JAVA_LONG = "long" + final val JAVA_FLOAT = "float" + final val JAVA_DOUBLE = "double" + private val curId = new java.util.concurrent.atomic.AtomicInteger() /** @@ -72,98 +80,94 @@ class CodeGenContext { } /** - * Return the code to access a column for given DataType + * Returns the code to access a column in Row for a given DataType. */ def getColumn(dataType: DataType, ordinal: Int): String = { - if (isNativeType(dataType)) { - s"i.${accessorForType(dataType)}($ordinal)" + val jt = javaType(dataType) + if (isPrimitiveType(jt)) { + s"i.get${primitiveTypeName(jt)}($ordinal)" } else { - s"(${boxedType(dataType)})i.apply($ordinal)" + s"($jt)i.apply($ordinal)" } } /** - * Return the code to update a column in Row for given DataType + * Returns the code to update a column in Row for a given DataType. */ def setColumn(dataType: DataType, ordinal: Int, value: String): String = { - if (isNativeType(dataType)) { - s"${mutatorForType(dataType)}($ordinal, $value)" + val jt = javaType(dataType) + if (isPrimitiveType(jt)) { + s"set${primitiveTypeName(jt)}($ordinal, $value)" } else { s"update($ordinal, $value)" } } /** - * Return the name of accessor in Row for a DataType + * Returns the name used in accessor and setter for a Java primitive type. 
*/ - def accessorForType(dt: DataType): String = dt match { - case IntegerType => "getInt" - case other => s"get${boxedType(dt)}" + def primitiveTypeName(jt: String): String = jt match { + case JAVA_INT => "Int" + case _ => boxedType(jt) } - /** - * Return the name of mutator in Row for a DataType - */ - def mutatorForType(dt: DataType): String = dt match { - case IntegerType => "setInt" - case other => s"set${boxedType(dt)}" - } + def primitiveTypeName(dt: DataType): String = primitiveTypeName(javaType(dt)) /** - * Return the Java type for a DataType + * Returns the Java type for a DataType. */ def javaType(dt: DataType): String = dt match { - case IntegerType => "int" - case LongType => "long" - case ShortType => "short" - case ByteType => "byte" - case DoubleType => "double" - case FloatType => "float" - case BooleanType => "boolean" + case BooleanType => JAVA_BOOLEAN + case ByteType => JAVA_BYTE + case ShortType => JAVA_SHORT + case IntegerType => JAVA_INT + case LongType => JAVA_LONG + case FloatType => JAVA_FLOAT + case DoubleType => JAVA_DOUBLE case dt: DecimalType => decimalType case BinaryType => "byte[]" case StringType => stringType - case DateType => "int" - case TimestampType => "long" + case DateType => JAVA_INT + case TimestampType => JAVA_LONG case dt: OpenHashSetUDT if dt.elementType == IntegerType => classOf[IntegerHashSet].getName case dt: OpenHashSetUDT if dt.elementType == LongType => classOf[LongHashSet].getName case _ => "Object" } /** - * Return the boxed type in Java + * Returns the boxed type in Java. */ - def boxedType(dt: DataType): String = dt match { - case IntegerType => "Integer" - case LongType => "Long" - case ShortType => "Short" - case ByteType => "Byte" - case DoubleType => "Double" - case FloatType => "Float" - case BooleanType => "Boolean" - case DateType => "Integer" - case TimestampType => "Long" - case _ => javaType(dt) + def boxedType(jt: String): String = jt match { + case JAVA_BOOLEAN => "Boolean" + case JAVA_BYTE => "Byte" + case JAVA_SHORT => "Short" + case JAVA_INT => "Integer" + case JAVA_LONG => "Long" + case JAVA_FLOAT => "Float" + case JAVA_DOUBLE => "Double" + case other => other } + def boxedType(dt: DataType): String = boxedType(javaType(dt)) + /** - * Return the representation of default value for given DataType + * Returns the representation of default value for a given Java Type. */ - def defaultValue(dt: DataType): String = dt match { - case BooleanType => "false" - case FloatType => "-1.0f" - case ShortType => "(short)-1" - case LongType => "-1L" - case ByteType => "(byte)-1" - case DoubleType => "-1.0" - case IntegerType => "-1" - case DateType => "-1" - case TimestampType => "-1L" + def defaultValue(jt: String): String = jt match { + case JAVA_BOOLEAN => "false" + case JAVA_BYTE => "(byte)-1" + case JAVA_SHORT => "(short)-1" + case JAVA_INT => "-1" + case JAVA_LONG => "-1L" + case JAVA_FLOAT => "-1.0f" + case JAVA_DOUBLE => "-1.0" case _ => "null" } + def defaultValue(dt: DataType): String = defaultValue(javaType(dt)) + /** - * Generate code for equal expression in Java + * Generates code for equal expression in Java. */ def genEqual(dataType: DataType, c1: String, c2: String): String = dataType match { case BinaryType => s"java.util.Arrays.equals($c1, $c2)" @@ -172,7 +176,7 @@ class CodeGenContext { } /** - * Generate code for compare expression in Java + * Generates code for compare expression in Java. 
*/ def genComp(dataType: DataType, c1: String, c2: String): String = dataType match { // java boolean doesn't support > or < operator @@ -184,25 +188,17 @@ class CodeGenContext { } /** - * List of data types that have special accessors and setters in [[InternalRow]]. + * List of java data types that have special accessors and setters in [[InternalRow]]. */ - val nativeTypes = - Seq(IntegerType, BooleanType, LongType, DoubleType, FloatType, ShortType, ByteType) + val primitiveTypes = + Seq(JAVA_BOOLEAN, JAVA_BYTE, JAVA_SHORT, JAVA_INT, JAVA_LONG, JAVA_FLOAT, JAVA_DOUBLE) /** - * Returns true if the data type has a special accessor and setter in [[InternalRow]]. + * Returns true if the Java type has a special accessor and setter in [[InternalRow]]. */ - def isNativeType(dt: DataType): Boolean = nativeTypes.contains(dt) + def isPrimitiveType(jt: String): Boolean = primitiveTypes.contains(jt) - /** - * List of data types who's Java type is primitive type - */ - val primitiveTypes = nativeTypes ++ Seq(DateType, TimestampType) - - /** - * Returns true if the Java type is primitive type - */ - def isPrimitiveType(dt: DataType): Boolean = primitiveTypes.contains(dt) + def isPrimitiveType(dt: DataType): Boolean = isPrimitiveType(javaType(dt)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala index e362625469e29..624e1cf4e201a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala @@ -72,54 +72,56 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { s"case $i: { c$i = (${ctx.boxedType(e.dataType)})value; return;}" }.mkString("\n ") - val specificAccessorFunctions = ctx.nativeTypes.map { dataType => + val specificAccessorFunctions = ctx.primitiveTypes.map { jt => val cases = expressions.zipWithIndex.flatMap { - case (e, i) if ctx.javaType(e.dataType) == ctx.javaType(dataType) => - List(s"case $i: return c$i;") - case _ => Nil + case (e, i) if ctx.javaType(e.dataType) == jt => + Some(s"case $i: return c$i;") + case _ => None }.mkString("\n ") if (cases.length > 0) { + val getter = "get" + ctx.primitiveTypeName(jt) s""" @Override - public ${ctx.javaType(dataType)} ${ctx.accessorForType(dataType)}(int i) { + public $jt $getter(int i) { if (isNullAt(i)) { - return ${ctx.defaultValue(dataType)}; + return ${ctx.defaultValue(jt)}; } switch (i) { $cases } throw new IllegalArgumentException("Invalid index: " + i - + " in ${ctx.accessorForType(dataType)}"); + + " in $getter"); }""" } else { "" } - }.mkString("\n") + }.filter(_.length > 0).mkString("\n") - val specificMutatorFunctions = ctx.nativeTypes.map { dataType => + val specificMutatorFunctions = ctx.primitiveTypes.map { jt => val cases = expressions.zipWithIndex.flatMap { - case (e, i) if ctx.javaType(e.dataType) == ctx.javaType(dataType) => - List(s"case $i: { c$i = value; return; }") - case _ => Nil + case (e, i) if ctx.javaType(e.dataType) == jt => + Some(s"case $i: { c$i = value; return; }") + case _ => None }.mkString("\n ") if (cases.length > 0) { + val setter = "set" + ctx.primitiveTypeName(jt) s""" @Override - public void ${ctx.mutatorForType(dataType)}(int i, ${ctx.javaType(dataType)} value) { + public void $setter(int i, $jt value) { nullBits[i] = false; switch (i) { $cases } 
throw new IllegalArgumentException("Invalid index: " + i + - " in ${ctx.mutatorForType(dataType)}"); + " in $setter}"); }""" } else { "" } - }.mkString("\n") + }.filter(_.length > 0).mkString("\n") val hashValues = expressions.zipWithIndex.map { case (e, i) => - val col = newTermName(s"c$i") + val col = s"c$i" val nonNull = e.dataType match { case BooleanType => s"$col ? 0 : 1" case ByteType | ShortType | IntegerType | DateType => s"$col" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 44416e79cd7aa..a6225fdafedde 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql.catalyst.expressions import java.util.regex.Pattern import org.apache.spark.sql.catalyst.analysis.UnresolvedException -import org.apache.spark.sql.catalyst.expressions.Substring import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String From 1a79f0eb8da7e850c443383b3bb24e0bf8e1e7cb Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 25 Jun 2015 22:44:26 -0700 Subject: [PATCH 016/122] [SPARK-8635] [SQL] improve performance of CatalystTypeConverters In `CatalystTypeConverters.createToCatalystConverter`, we add special handling for primitive types. We can apply this strategy to more places to improve performance. Author: Wenchen Fan Closes #7018 from cloud-fan/converter and squashes the following commits: 8b16630 [Wenchen Fan] another fix 326c82c [Wenchen Fan] optimize type converter --- .../sql/catalyst/CatalystTypeConverters.scala | 60 ++++++++++++------- .../sql/catalyst/expressions/ScalaUdf.scala | 3 +- .../org/apache/spark/sql/DataFrame.scala | 4 +- .../sql/execution/stat/FrequentItems.scala | 2 +- .../sql/execution/stat/StatFunctions.scala | 2 +- .../sql/sources/DataSourceStrategy.scala | 2 +- .../apache/spark/sql/sources/commands.scala | 4 +- .../spark/sql/sources/TableScanSuite.scala | 4 +- 8 files changed, 48 insertions(+), 33 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala index 429fc4077be9a..012f8bbecb4d3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala @@ -52,6 +52,13 @@ object CatalystTypeConverters { } } + private def isWholePrimitive(dt: DataType): Boolean = dt match { + case dt if isPrimitive(dt) => true + case ArrayType(elementType, _) => isWholePrimitive(elementType) + case MapType(keyType, valueType, _) => isWholePrimitive(keyType) && isWholePrimitive(valueType) + case _ => false + } + private def getConverterForType(dataType: DataType): CatalystTypeConverter[Any, Any, Any] = { val converter = dataType match { case udt: UserDefinedType[_] => UDTConverter(udt) @@ -148,6 +155,8 @@ object CatalystTypeConverters { private[this] val elementConverter = getConverterForType(elementType) + private[this] val isNoChange = isWholePrimitive(elementType) + override def toCatalystImpl(scalaValue: Any): Seq[Any] = { scalaValue match { case a: Array[_] => a.toSeq.map(elementConverter.toCatalyst) @@ 
-166,8 +175,10 @@ object CatalystTypeConverters { override def toScala(catalystValue: Seq[Any]): Seq[Any] = { if (catalystValue == null) { null + } else if (isNoChange) { + catalystValue } else { - catalystValue.asInstanceOf[Seq[_]].map(elementConverter.toScala) + catalystValue.map(elementConverter.toScala) } } @@ -183,6 +194,8 @@ object CatalystTypeConverters { private[this] val keyConverter = getConverterForType(keyType) private[this] val valueConverter = getConverterForType(valueType) + private[this] val isNoChange = isWholePrimitive(keyType) && isWholePrimitive(valueType) + override def toCatalystImpl(scalaValue: Any): Map[Any, Any] = scalaValue match { case m: Map[_, _] => m.map { case (k, v) => @@ -203,6 +216,8 @@ object CatalystTypeConverters { override def toScala(catalystValue: Map[Any, Any]): Map[Any, Any] = { if (catalystValue == null) { null + } else if (isNoChange) { + catalystValue } else { catalystValue.map { case (k, v) => keyConverter.toScala(k) -> valueConverter.toScala(v) @@ -258,16 +273,13 @@ object CatalystTypeConverters { toScala(row(column).asInstanceOf[InternalRow]) } - private object StringConverter extends CatalystTypeConverter[Any, String, Any] { + private object StringConverter extends CatalystTypeConverter[Any, String, UTF8String] { override def toCatalystImpl(scalaValue: Any): UTF8String = scalaValue match { case str: String => UTF8String.fromString(str) case utf8: UTF8String => utf8 } - override def toScala(catalystValue: Any): String = catalystValue match { - case null => null - case str: String => str - case utf8: UTF8String => utf8.toString() - } + override def toScala(catalystValue: UTF8String): String = + if (catalystValue == null) null else catalystValue.toString override def toScalaImpl(row: InternalRow, column: Int): String = row(column).toString } @@ -275,7 +287,8 @@ object CatalystTypeConverters { override def toCatalystImpl(scalaValue: Date): Int = DateTimeUtils.fromJavaDate(scalaValue) override def toScala(catalystValue: Any): Date = if (catalystValue == null) null else DateTimeUtils.toJavaDate(catalystValue.asInstanceOf[Int]) - override def toScalaImpl(row: InternalRow, column: Int): Date = toScala(row.getInt(column)) + override def toScalaImpl(row: InternalRow, column: Int): Date = + DateTimeUtils.toJavaDate(row.getInt(column)) } private object TimestampConverter extends CatalystTypeConverter[Timestamp, Timestamp, Any] { @@ -285,7 +298,7 @@ object CatalystTypeConverters { if (catalystValue == null) null else DateTimeUtils.toJavaTimestamp(catalystValue.asInstanceOf[Long]) override def toScalaImpl(row: InternalRow, column: Int): Timestamp = - toScala(row.getLong(column)) + DateTimeUtils.toJavaTimestamp(row.getLong(column)) } private object BigDecimalConverter extends CatalystTypeConverter[Any, JavaBigDecimal, Decimal] { @@ -296,10 +309,7 @@ object CatalystTypeConverters { } override def toScala(catalystValue: Decimal): JavaBigDecimal = catalystValue.toJavaBigDecimal override def toScalaImpl(row: InternalRow, column: Int): JavaBigDecimal = - row.get(column) match { - case d: JavaBigDecimal => d - case d: Decimal => d.toJavaBigDecimal - } + row.get(column).asInstanceOf[Decimal].toJavaBigDecimal } private abstract class PrimitiveConverter[T] extends CatalystTypeConverter[T, Any, Any] { @@ -362,6 +372,19 @@ object CatalystTypeConverters { } } + /** + * Creates a converter function that will convert Catalyst types to Scala type. + * Typical use case would be converting a collection of rows that have the same schema. 
You will + * call this function once to get a converter, and apply it to every row. + */ + private[sql] def createToScalaConverter(dataType: DataType): Any => Any = { + if (isPrimitive(dataType)) { + identity + } else { + getConverterForType(dataType).toScala + } + } + /** * Converts Scala objects to Catalyst rows / types. * @@ -389,15 +412,6 @@ object CatalystTypeConverters { * produced by createToScalaConverter. */ def convertToScala(catalystValue: Any, dataType: DataType): Any = { - getConverterForType(dataType).toScala(catalystValue) - } - - /** - * Creates a converter function that will convert Catalyst types to Scala type. - * Typical use case would be converting a collection of rows that have the same schema. You will - * call this function once to get a converter, and apply it to every row. - */ - private[sql] def createToScalaConverter(dataType: DataType): Any => Any = { - getConverterForType(dataType).toScala + createToScalaConverter(dataType)(catalystValue) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala index 3992f1f59dad8..55df72f102295 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst import org.apache.spark.sql.catalyst.CatalystTypeConverters import org.apache.spark.sql.types.DataType @@ -39,7 +38,7 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi (1 to 22).map { x => val anys = (1 to x).map(x => "Any").reduce(_ + ", " + _) val childs = (0 to x - 1).map(x => s"val child$x = children($x)").reduce(_ + "\n " + _) - lazy val converters = (0 to x - 1).map(x => s"lazy val converter$x = CatalystTypeConverters.createToScalaConverter(child$x.dataType)").reduce(_ + "\n " + _) + val converters = (0 to x - 1).map(x => s"lazy val converter$x = CatalystTypeConverters.createToScalaConverter(child$x.dataType)").reduce(_ + "\n " + _) val evals = (0 to x - 1).map(x => s"converter$x(child$x.eval(input))").reduce(_ + ",\n " + _) s"""case $x => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index f3f0f5305318e..0db4df34f9e22 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -1418,12 +1418,14 @@ class DataFrame private[sql]( lazy val rdd: RDD[Row] = { // use a local variable to make sure the map closure doesn't capture the whole DataFrame val schema = this.schema - queryExecution.executedPlan.execute().mapPartitions { rows => + internalRowRdd.mapPartitions { rows => val converter = CatalystTypeConverters.createToScalaConverter(schema) rows.map(converter(_).asInstanceOf[Row]) } } + private[sql] def internalRowRdd = queryExecution.executedPlan.execute() + /** * Returns the content of the [[DataFrame]] as a [[JavaRDD]] of [[Row]]s. 
* @group rdd diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala index 8df1da037c434..3ebbf96090a55 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala @@ -90,7 +90,7 @@ private[sql] object FrequentItems extends Logging { (name, originalSchema.fields(index).dataType) } - val freqItems = df.select(cols.map(Column(_)) : _*).rdd.aggregate(countMaps)( + val freqItems = df.select(cols.map(Column(_)) : _*).internalRowRdd.aggregate(countMaps)( seqOp = (counts, row) => { var i = 0 while (i < numCols) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala index 93383e5a62f11..252c611d02ebc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala @@ -81,7 +81,7 @@ private[sql] object StatFunctions extends Logging { s"with dataType ${data.get.dataType} not supported.") } val columns = cols.map(n => Column(Cast(Column(n).expr, DoubleType))) - df.select(columns: _*).rdd.aggregate(new CovarianceCounter)( + df.select(columns: _*).internalRowRdd.aggregate(new CovarianceCounter)( seqOp = (counter, row) => { counter.add(row.getDouble(0), row.getDouble(1)) }, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala index a8f56f4767407..ce16e050c56ed 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala @@ -313,7 +313,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { output: Seq[Attribute], rdd: RDD[Row]): RDD[InternalRow] = { if (relation.relation.needConversion) { - execution.RDDConversions.rowToRowRdd(rdd.asInstanceOf[RDD[Row]], output.map(_.dataType)) + execution.RDDConversions.rowToRowRdd(rdd, output.map(_.dataType)) } else { rdd.map(_.asInstanceOf[InternalRow]) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala index fb6173f58ece6..dbb369cf45502 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala @@ -154,7 +154,7 @@ private[sql] case class InsertIntoHadoopFsRelation( writerContainer.driverSideSetup() try { - df.sqlContext.sparkContext.runJob(df.queryExecution.executedPlan.execute(), writeRows _) + df.sqlContext.sparkContext.runJob(df.internalRowRdd, writeRows _) writerContainer.commitJob() relation.refresh() } catch { case cause: Throwable => @@ -220,7 +220,7 @@ private[sql] case class InsertIntoHadoopFsRelation( writerContainer.driverSideSetup() try { - df.sqlContext.sparkContext.runJob(df.queryExecution.executedPlan.execute(), writeRows _) + df.sqlContext.sparkContext.runJob(df.internalRowRdd, writeRows _) writerContainer.commitJob() relation.refresh() } catch { case cause: Throwable => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala index 
79eac930e54f7..de0ed0c0427a6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala @@ -88,9 +88,9 @@ case class AllDataTypesScan( UTF8String.fromString(s"varchar_$i"), Seq(i, i + 1), Seq(Map(UTF8String.fromString(s"str_$i") -> InternalRow(i.toLong))), - Map(i -> i.toString), + Map(i -> UTF8String.fromString(i.toString)), Map(Map(UTF8String.fromString(s"str_$i") -> i.toFloat) -> InternalRow(i.toLong)), - Row(i, i.toString), + Row(i, UTF8String.fromString(i.toString)), Row(Seq(UTF8String.fromString(s"str_$i"), UTF8String.fromString(s"str_${i + 1}")), InternalRow(Seq(DateTimeUtils.fromJavaDate(new Date(1970, 1, i + 1)))))) } From 9fed6abfdcb7afcf92be56e5ccbed6599fe66bc4 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 26 Jun 2015 00:12:05 -0700 Subject: [PATCH 017/122] [SPARK-8344] Add message processing time metric to DAGScheduler This commit adds a new metric, `messageProcessingTime`, to the DAGScheduler metrics source. This metrics tracks the time taken to process messages in the scheduler's event processing loop, which is a helpful debugging aid for diagnosing performance issues in the scheduler (such as SPARK-4961). In order to do this, I moved the creation of the DAGSchedulerSource metrics source into DAGScheduler itself, similar to how MasterSource is created and registered in Master. Author: Josh Rosen Closes #7002 from JoshRosen/SPARK-8344 and squashes the following commits: 57f914b [Josh Rosen] Fix import ordering 7d6bb83 [Josh Rosen] Add message processing time metrics to DAGScheduler --- .../scala/org/apache/spark/SparkContext.scala | 1 - .../apache/spark/scheduler/DAGScheduler.scala | 18 ++++++++++++++++-- .../spark/scheduler/DAGSchedulerSource.scala | 8 ++++++-- 3 files changed, 22 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 141276ac901fb..c7a7436462083 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -545,7 +545,6 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli // Post init _taskScheduler.postStartHook() - _env.metricsSystem.registerSource(new DAGSchedulerSource(dagScheduler)) _env.metricsSystem.registerSource(new BlockManagerSource(_env.blockManager)) _executorAllocationManager.foreach { e => _env.metricsSystem.registerSource(e.executorAllocationManagerSource) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index aea6674ed20be..b00a5fee09bf2 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -81,6 +81,8 @@ class DAGScheduler( def this(sc: SparkContext) = this(sc, sc.taskScheduler) + private[scheduler] val metricsSource: DAGSchedulerSource = new DAGSchedulerSource(this) + private[scheduler] val nextJobId = new AtomicInteger(0) private[scheduler] def numTotalJobs: Int = nextJobId.get() private val nextStageId = new AtomicInteger(0) @@ -1438,17 +1440,29 @@ class DAGScheduler( taskScheduler.stop() } - // Start the event thread at the end of the constructor + // Start the event thread and register the metrics source at the end of the constructor + env.metricsSystem.registerSource(metricsSource) eventProcessLoop.start() } 
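The hunk above registers the new metrics source when the scheduler is constructed; the hunk that follows wraps event handling in the new timer. For readers unfamiliar with the Dropwizard (Codahale) metrics API this relies on, here is a minimal, standalone sketch — assumed names and a stand-in workload, not Spark code — of how such a `Timer` is created, used, and read back:

```scala
// Minimal sketch of the Dropwizard/Codahale Timer pattern (not Spark code).
import com.codahale.metrics.{MetricRegistry, Timer}

object TimerSketch {
  def main(args: Array[String]): Unit = {
    val registry = new MetricRegistry()
    val timer: Timer = registry.timer(MetricRegistry.name("messageProcessingTime"))

    // Time one unit of work, using the same try/finally shape as onReceive below.
    val ctx = timer.time()
    try {
      Thread.sleep(5) // stand-in for handling one scheduler event
    } finally {
      ctx.stop()
    }

    // A metrics sink (or this println) can then read counts and latency statistics.
    println(s"events=${timer.getCount} meanNanos=${timer.getSnapshot.getMean}")
  }
}
```

Because the timer lives in a registered source, any sink configured through Spark's metrics system can report these statistics without further code changes.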
private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler) extends EventLoop[DAGSchedulerEvent]("dag-scheduler-event-loop") with Logging { + private[this] val timer = dagScheduler.metricsSource.messageProcessingTimer + /** * The main event loop of the DAG scheduler. */ - override def onReceive(event: DAGSchedulerEvent): Unit = event match { + override def onReceive(event: DAGSchedulerEvent): Unit = { + val timerContext = timer.time() + try { + doOnReceive(event) + } finally { + timerContext.stop() + } + } + + private def doOnReceive(event: DAGSchedulerEvent): Unit = event match { case JobSubmitted(jobId, rdd, func, partitions, allowLocal, callSite, listener, properties) => dagScheduler.handleJobSubmitted(jobId, rdd, func, partitions, allowLocal, callSite, listener, properties) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala index 02c67073af6a0..6b667d5d7645b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerSource.scala @@ -17,11 +17,11 @@ package org.apache.spark.scheduler -import com.codahale.metrics.{Gauge, MetricRegistry} +import com.codahale.metrics.{Gauge, MetricRegistry, Timer} import org.apache.spark.metrics.source.Source -private[spark] class DAGSchedulerSource(val dagScheduler: DAGScheduler) +private[scheduler] class DAGSchedulerSource(val dagScheduler: DAGScheduler) extends Source { override val metricRegistry = new MetricRegistry() override val sourceName = "DAGScheduler" @@ -45,4 +45,8 @@ private[spark] class DAGSchedulerSource(val dagScheduler: DAGScheduler) metricRegistry.register(MetricRegistry.name("job", "activeJobs"), new Gauge[Int] { override def getValue: Int = dagScheduler.activeJobs.size }) + + /** Timer that tracks the time to process messages in the DAGScheduler's event loop */ + val messageProcessingTimer: Timer = + metricRegistry.timer(MetricRegistry.name("messageProcessingTime")) } From c9e05a315a96fbf3026a2b3c6934dd2dec420099 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Fri, 26 Jun 2015 01:19:05 -0700 Subject: [PATCH 018/122] [SPARK-8613] [ML] [TRIVIAL] add param to disable linear feature scaling Add a param to disable linear feature scaling (to be implemented later in linear & logistic regression). Done as a seperate PR so we can use same param & not conflict while working on the sub-tasks. Author: Holden Karau Closes #7024 from holdenk/SPARK-8522-Disable-Linear_featureScaling-Spark-8613-Add-param and squashes the following commits: ce8931a [Holden Karau] Regenerate the sharedParams code fa6427e [Holden Karau] update text for standardization param. 
7b24a2b [Holden Karau] generate the new standardization param 3c190af [Holden Karau] Add the standardization param to sharedparamscodegen --- .../ml/param/shared/SharedParamsCodeGen.scala | 3 +++ .../spark/ml/param/shared/sharedParams.scala | 17 +++++++++++++++++ 2 files changed, 20 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala index 8ffbcf0d8bc71..b0a6af171c01f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -53,6 +53,9 @@ private[shared] object SharedParamsCodeGen { ParamDesc[Int]("checkpointInterval", "checkpoint interval (>= 1)", isValid = "ParamValidators.gtEq(1)"), ParamDesc[Boolean]("fitIntercept", "whether to fit an intercept term", Some("true")), + ParamDesc[Boolean]("standardization", "whether to standardize the training features" + + " prior to fitting the model sequence. Note that the coefficients of models are" + + " always returned on the original scale.", Some("true")), ParamDesc[Long]("seed", "random seed", Some("this.getClass.getName.hashCode.toLong")), ParamDesc[Double]("elasticNetParam", "the ElasticNet mixing parameter, in range [0, 1]." + " For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.", diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index a0c8ccdac9ad9..bbe08939b6d75 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -233,6 +233,23 @@ private[ml] trait HasFitIntercept extends Params { final def getFitIntercept: Boolean = $(fitIntercept) } +/** + * (private[ml]) Trait for shared param standardization (default: true). + */ +private[ml] trait HasStandardization extends Params { + + /** + * Param for whether to standardize the training features prior to fitting the model sequence. Note that the coefficients of models are always returned on the original scale.. + * @group param + */ + final val standardization: BooleanParam = new BooleanParam(this, "standardization", "whether to standardize the training features prior to fitting the model sequence. Note that the coefficients of models are always returned on the original scale.") + + setDefault(standardization, true) + + /** @group getParam */ + final def getStandardization: Boolean = $(standardization) +} + /** * (private[ml]) Trait for shared param seed (default: this.getClass.getName.hashCode.toLong). */ From 37bf76a2de2143ec6348a3d43b782227849520cc Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Fri, 26 Jun 2015 08:45:22 -0500 Subject: [PATCH 019/122] [SPARK-8302] Support heterogeneous cluster install paths on YARN. Some users have Hadoop installations on different paths across their cluster. Currently, that makes it hard to set up some configuration in Spark since that requires hardcoding paths to jar files or native libraries, which wouldn't work on such a cluster. This change introduces a couple of YARN-specific configurations that instruct the backend to replace certain paths when launching remote processes. 
That way, if the configuration says the Spark jar is in "/spark/spark.jar", and also says that "/spark" should be replaced with "{{SPARK_INSTALL_DIR}}", YARN will start containers in the NMs with "{{SPARK_INSTALL_DIR}}/spark.jar" as the location of the jar. Coupled with YARN's environment whitelist (which allows certain env variables to be exposed to containers), this allows users to support such heterogeneous environments, as long as a single replacement is enough. (Otherwise, this feature would need to be extended to support multiple path replacements.) Author: Marcelo Vanzin Closes #6752 from vanzin/SPARK-8302 and squashes the following commits: 4bff8d4 [Marcelo Vanzin] Add docs, rename configs. 0aa2a02 [Marcelo Vanzin] Only do replacement for paths that need it. 2e9cc9d [Marcelo Vanzin] Style. a5e1f68 [Marcelo Vanzin] [SPARK-8302] Support heterogeneous cluster install paths on YARN. --- docs/running-on-yarn.md | 26 ++++++++++ .../org/apache/spark/deploy/yarn/Client.scala | 47 +++++++++++++++---- .../spark/deploy/yarn/ExecutorRunnable.scala | 4 +- .../spark/deploy/yarn/ClientSuite.scala | 19 ++++++++ 4 files changed, 84 insertions(+), 12 deletions(-) diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index 96cf612c54fdd..3f8a093bbe957 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -258,6 +258,32 @@ Most of the configs are the same for Spark on YARN as for other deployment modes Principal to be used to login to KDC, while running on secure HDFS. + + spark.yarn.config.gatewayPath + (none) + + A path that is valid on the gateway host (the host where a Spark application is started) but may + differ for paths for the same resource in other nodes in the cluster. Coupled with + spark.yarn.config.replacementPath, this is used to support clusters with + heterogeneous configurations, so that Spark can correctly launch remote processes. +

+ The replacement path normally will contain a reference to some environment variable exported by + YARN (and, thus, visible to Spark containers). +

+ For example, if the gateway node has Hadoop libraries installed on /disk1/hadoop, and + the location of the Hadoop install is exported by YARN as the HADOOP_HOME + environment variable, setting this value to /disk1/hadoop and the replacement path to + $HADOOP_HOME will make sure that paths used to launch remote processes properly + reference the local YARN configuration. + + + + spark.yarn.config.replacementPath + (none) + + See spark.yarn.config.gatewayPath. + + # Launching Spark on YARN diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index da1ec2a0fe2e9..67a5c95400e53 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -676,7 +676,7 @@ private[spark] class Client( val libraryPaths = Seq(sys.props.get("spark.driver.extraLibraryPath"), sys.props.get("spark.driver.libraryPath")).flatten if (libraryPaths.nonEmpty) { - prefixEnv = Some(Utils.libraryPathEnvPrefix(libraryPaths)) + prefixEnv = Some(getClusterPath(sparkConf, Utils.libraryPathEnvPrefix(libraryPaths))) } if (sparkConf.getOption("spark.yarn.am.extraJavaOptions").isDefined) { logWarning("spark.yarn.am.extraJavaOptions will not take effect in cluster mode") @@ -698,7 +698,7 @@ private[spark] class Client( } sparkConf.getOption("spark.yarn.am.extraLibraryPath").foreach { paths => - prefixEnv = Some(Utils.libraryPathEnvPrefix(Seq(paths))) + prefixEnv = Some(getClusterPath(sparkConf, Utils.libraryPathEnvPrefix(Seq(paths)))) } } @@ -1106,10 +1106,10 @@ object Client extends Logging { env: HashMap[String, String], isAM: Boolean, extraClassPath: Option[String] = None): Unit = { - extraClassPath.foreach(addClasspathEntry(_, env)) - addClasspathEntry( - YarnSparkHadoopUtil.expandEnvironment(Environment.PWD), env - ) + extraClassPath.foreach { cp => + addClasspathEntry(getClusterPath(sparkConf, cp), env) + } + addClasspathEntry(YarnSparkHadoopUtil.expandEnvironment(Environment.PWD), env) if (isAM) { addClasspathEntry( @@ -1125,12 +1125,14 @@ object Client extends Logging { getUserClasspath(sparkConf) } userClassPath.foreach { x => - addFileToClasspath(x, null, env) + addFileToClasspath(sparkConf, x, null, env) } } - addFileToClasspath(new URI(sparkJar(sparkConf)), SPARK_JAR, env) + addFileToClasspath(sparkConf, new URI(sparkJar(sparkConf)), SPARK_JAR, env) populateHadoopClasspath(conf, env) - sys.env.get(ENV_DIST_CLASSPATH).foreach(addClasspathEntry(_, env)) + sys.env.get(ENV_DIST_CLASSPATH).foreach { cp => + addClasspathEntry(getClusterPath(sparkConf, cp), env) + } } /** @@ -1159,16 +1161,18 @@ object Client extends Logging { * * If not a "local:" file and no alternate name, the environment is not modified. * + * @parma conf Spark configuration. * @param uri URI to add to classpath (optional). * @param fileName Alternate name for the file (optional). * @param env Map holding the environment variables. 
*/ private def addFileToClasspath( + conf: SparkConf, uri: URI, fileName: String, env: HashMap[String, String]): Unit = { if (uri != null && uri.getScheme == LOCAL_SCHEME) { - addClasspathEntry(uri.getPath, env) + addClasspathEntry(getClusterPath(conf, uri.getPath), env) } else if (fileName != null) { addClasspathEntry(buildPath( YarnSparkHadoopUtil.expandEnvironment(Environment.PWD), fileName), env) @@ -1182,6 +1186,29 @@ object Client extends Logging { private def addClasspathEntry(path: String, env: HashMap[String, String]): Unit = YarnSparkHadoopUtil.addPathToEnvironment(env, Environment.CLASSPATH.name, path) + /** + * Returns the path to be sent to the NM for a path that is valid on the gateway. + * + * This method uses two configuration values: + * + * - spark.yarn.config.gatewayPath: a string that identifies a portion of the input path that may + * only be valid in the gateway node. + * - spark.yarn.config.replacementPath: a string with which to replace the gateway path. This may + * contain, for example, env variable references, which will be expanded by the NMs when + * starting containers. + * + * If either config is not available, the input path is returned. + */ + def getClusterPath(conf: SparkConf, path: String): String = { + val localPath = conf.get("spark.yarn.config.gatewayPath", null) + val clusterPath = conf.get("spark.yarn.config.replacementPath", null) + if (localPath != null && clusterPath != null) { + path.replace(localPath, clusterPath) + } else { + path + } + } + /** * Obtains token for the Hive metastore and adds them to the credentials. */ diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala index b0937083bc536..78e27fb7f3337 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala @@ -146,7 +146,7 @@ class ExecutorRunnable( javaOpts ++= Utils.splitCommandString(opts).map(YarnSparkHadoopUtil.escapeForShell) } sys.props.get("spark.executor.extraLibraryPath").foreach { p => - prefixEnv = Some(Utils.libraryPathEnvPrefix(Seq(p))) + prefixEnv = Some(Client.getClusterPath(sparkConf, Utils.libraryPathEnvPrefix(Seq(p)))) } javaOpts += "-Djava.io.tmpdir=" + @@ -195,7 +195,7 @@ class ExecutorRunnable( val userClassPath = Client.getUserClasspath(sparkConf).flatMap { uri => val absPath = if (new File(uri.getPath()).isAbsolute()) { - uri.getPath() + Client.getClusterPath(sparkConf, uri.getPath()) } else { Client.buildPath(Environment.PWD.$(), uri.getPath()) } diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala index 4ec976aa31387..837f8d3fa55a7 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala @@ -151,6 +151,25 @@ class ClientSuite extends SparkFunSuite with Matchers with BeforeAndAfterAll { } } + test("Cluster path translation") { + val conf = new Configuration() + val sparkConf = new SparkConf() + .set(Client.CONF_SPARK_JAR, "local:/localPath/spark.jar") + .set("spark.yarn.config.gatewayPath", "/localPath") + .set("spark.yarn.config.replacementPath", "/remotePath") + + Client.getClusterPath(sparkConf, "/localPath") should be ("/remotePath") + Client.getClusterPath(sparkConf, "/localPath/1:/localPath/2") should be ( + "/remotePath/1:/remotePath/2") + + val env = new 
MutableHashMap[String, String]() + Client.populateClasspath(null, conf, sparkConf, env, false, + extraClassPath = Some("/localPath/my1.jar")) + val cp = classpath(env) + cp should contain ("/remotePath/spark.jar") + cp should contain ("/remotePath/my1.jar") + } + object Fixtures { val knownDefYarnAppCP: Seq[String] = From 41afa16500e682475eaa80e31c0434b7ab66abcb Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Fri, 26 Jun 2015 08:12:22 -0700 Subject: [PATCH 020/122] [SPARK-8652] [PYSPARK] Check return value for all uses of doctest.testmod() This patch addresses a critical issue in the PySpark tests: Several of our Python modules' `__main__` methods call `doctest.testmod()` in order to run doctests but forget to check and handle its return value. As a result, some PySpark test failures can go unnoticed because they will not fail the build. Fortunately, there was only one test failure which was masked by this bug: a `pyspark.profiler` doctest was failing due to changes in RDD pipelining. Author: Josh Rosen Closes #7032 from JoshRosen/testmod-fix and squashes the following commits: 60dbdc0 [Josh Rosen] Account for int vs. long formatting change in Python 3 8b8d80a [Josh Rosen] Fix failing test. e6423f9 [Josh Rosen] Check return code for all uses of doctest.testmod(). --- dev/merge_spark_pr.py | 4 +++- python/pyspark/accumulators.py | 4 +++- python/pyspark/broadcast.py | 4 +++- python/pyspark/heapq3.py | 5 +++-- python/pyspark/profiler.py | 8 ++++++-- python/pyspark/serializers.py | 8 +++++--- python/pyspark/shuffle.py | 4 +++- python/pyspark/streaming/util.py | 4 +++- 8 files changed, 29 insertions(+), 12 deletions(-) diff --git a/dev/merge_spark_pr.py b/dev/merge_spark_pr.py index cd83b352c1bfb..cf827ce89b857 100755 --- a/dev/merge_spark_pr.py +++ b/dev/merge_spark_pr.py @@ -431,6 +431,8 @@ def main(): if __name__ == "__main__": import doctest - doctest.testmod() + (failure_count, test_count) = doctest.testmod() + if failure_count: + exit(-1) main() diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py index adca90ddaf397..6ef8cf53cc747 100644 --- a/python/pyspark/accumulators.py +++ b/python/pyspark/accumulators.py @@ -264,4 +264,6 @@ def _start_update_server(): if __name__ == "__main__": import doctest - doctest.testmod() + (failure_count, test_count) = doctest.testmod() + if failure_count: + exit(-1) diff --git a/python/pyspark/broadcast.py b/python/pyspark/broadcast.py index 3de4615428bb6..663c9abe0881e 100644 --- a/python/pyspark/broadcast.py +++ b/python/pyspark/broadcast.py @@ -115,4 +115,6 @@ def __reduce__(self): if __name__ == "__main__": import doctest - doctest.testmod() + (failure_count, test_count) = doctest.testmod() + if failure_count: + exit(-1) diff --git a/python/pyspark/heapq3.py b/python/pyspark/heapq3.py index 4ef2afe03544f..b27e91a4cc251 100644 --- a/python/pyspark/heapq3.py +++ b/python/pyspark/heapq3.py @@ -883,6 +883,7 @@ def nlargest(n, iterable, key=None): if __name__ == "__main__": - import doctest - print(doctest.testmod()) + (failure_count, test_count) = doctest.testmod() + if failure_count: + exit(-1) diff --git a/python/pyspark/profiler.py b/python/pyspark/profiler.py index d18daaabfcb3c..44d17bd629473 100644 --- a/python/pyspark/profiler.py +++ b/python/pyspark/profiler.py @@ -90,9 +90,11 @@ class Profiler(object): >>> sc = SparkContext('local', 'test', conf=conf, profiler_cls=MyCustomProfiler) >>> sc.parallelize(range(1000)).map(lambda x: 2 * x).take(10) [0, 2, 4, 6, 8, 10, 12, 14, 16, 18] + >>> sc.parallelize(range(1000)).count() + 
1000 >>> sc.show_profiles() My custom profiles for RDD:1 - My custom profiles for RDD:2 + My custom profiles for RDD:3 >>> sc.stop() """ @@ -169,4 +171,6 @@ def stats(self): if __name__ == "__main__": import doctest - doctest.testmod() + (failure_count, test_count) = doctest.testmod() + if failure_count: + exit(-1)
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 7f9d0a338d31e..411b4dbf481f1 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -44,8 +44,8 @@ >>> rdd.glom().collect() [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]] ->>> rdd._jrdd.count() -8L +>>> int(rdd._jrdd.count()) +8 >>> sc.stop() """ @@ -556,4 +556,6 @@ def write_with_length(obj, stream): if __name__ == '__main__': import doctest - doctest.testmod() + (failure_count, test_count) = doctest.testmod() + if failure_count: + exit(-1)
diff --git a/python/pyspark/shuffle.py b/python/pyspark/shuffle.py index 67752c0d150b9..8fb71bac64a5e 100644 --- a/python/pyspark/shuffle.py +++ b/python/pyspark/shuffle.py @@ -838,4 +838,6 @@ def load_partition(j): if __name__ == "__main__": import doctest - doctest.testmod() + (failure_count, test_count) = doctest.testmod() + if failure_count: + exit(-1)
diff --git a/python/pyspark/streaming/util.py b/python/pyspark/streaming/util.py index 34291f30a5652..a9bfec2aab8fc 100644 --- a/python/pyspark/streaming/util.py +++ b/python/pyspark/streaming/util.py @@ -125,4 +125,6 @@ def rddToFileName(prefix, suffix, timestamp): if __name__ == "__main__": import doctest - doctest.testmod() + (failure_count, test_count) = doctest.testmod() + if failure_count: + exit(-1)
From a56516fc9280724db8fdef8e7d109ed7e28e427d Mon Sep 17 00:00:00 2001 From: cafreeman Date: Fri, 26 Jun 2015 10:07:35 -0700 Subject: [PATCH 021/122] [SPARK-8662] SparkR Update SparkSQL Test
Test `infer_type` using a more fine-grained approach rather than comparing environments. Since `all.equal`'s behavior has changed in R 3.2, the test could no longer pass.
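For readers unfamiliar with SparkR internals: the `structField` objects checked here are thin wrappers around Spark SQL's JVM `StructField`, so field-by-field assertions map directly onto its accessors. A rough Scala sketch of the equivalent checks (illustrative only; not part of the patch, and the literal "IntegerType" string is assumed to match the JVM `toString` that the R test compares against):

    import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

    // Illustrative sketch: assert on name, data type and nullability individually,
    // mirroring the checkStructField helper added in the R test below.
    val inferred = StructType(Array(
      StructField("a", IntegerType, nullable = true),
      StructField("b", StringType, nullable = true)))

    assert(inferred.fields(0).name == "a")
    assert(inferred.fields(0).dataType.toString == "IntegerType") // assumed string form
    assert(inferred.fields(0).nullable)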
JIRA here: https://issues.apache.org/jira/browse/SPARK-8662 Author: cafreeman Closes #7045 from cafreeman/R32_Test and squashes the following commits: b97cc52 [cafreeman] Add `checkStructField` utility 3381e5c [cafreeman] Update SparkSQL Test (cherry picked from commit 78b31a2a630c2178987322d0221aeea183ec565f) Signed-off-by: Shivaram Venkataraman --- R/pkg/inst/tests/test_sparkSQL.R | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index 417153dc0985c..6a08f894313c4 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -19,6 +19,14 @@ library(testthat) context("SparkSQL functions") +# Utility function for easily checking the values of a StructField +checkStructField <- function(actual, expectedName, expectedType, expectedNullable) { + expect_equal(class(actual), "structField") + expect_equal(actual$name(), expectedName) + expect_equal(actual$dataType.toString(), expectedType) + expect_equal(actual$nullable(), expectedNullable) +} + # Tests for SparkSQL functions in SparkR sc <- sparkR.init() @@ -52,9 +60,10 @@ test_that("infer types", { list(type = 'array', elementType = "integer", containsNull = TRUE)) expect_equal(infer_type(list(1L, 2L)), list(type = 'array', elementType = "integer", containsNull = TRUE)) - expect_equal(infer_type(list(a = 1L, b = "2")), - structType(structField(x = "a", type = "integer", nullable = TRUE), - structField(x = "b", type = "string", nullable = TRUE))) + testStruct <- infer_type(list(a = 1L, b = "2")) + expect_true(class(testStruct) == "structType") + checkStructField(testStruct$fields()[[1]], "a", "IntegerType", TRUE) + checkStructField(testStruct$fields()[[2]], "b", "StringType", TRUE) e <- new.env() assign("a", 1L, envir = e) expect_equal(infer_type(e), From 9d11817765e2817b11b73c61bae3b32c9f119cfd Mon Sep 17 00:00:00 2001 From: cafreeman Date: Fri, 26 Jun 2015 17:06:02 -0700 Subject: [PATCH 022/122] [SPARK-8607] SparkR -- jars not being added to application classpath correctly Add `getStaticClass` method in SparkR's `RBackendHandler` This is a fix for the problem referenced in [SPARK-5185](https://issues.apache.org/jira/browse/SPARK-5185). 
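The gist of the fix, as the diff below shows, is a two-step class lookup: try the default `Class.forName` first, and on failure retry through the current thread's context class loader, which does see JARs added via `--jars`. A simplified sketch of that pattern (the real method is `getStaticClass` in `RBackendHandler`; its exact error handling may differ):

    // Simplified sketch of the fallback lookup used by the fix.
    def loadStaticClass(objId: String): Class[_] = {
      try {
        Class.forName(objId)
      } catch {
        case _: ClassNotFoundException =>
          // The context class loader also knows about JARs added at launch time,
          // which the default lookup above may miss.
          Class.forName(objId, true, Thread.currentThread().getContextClassLoader)
      }
    }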
cc shivaram Author: cafreeman Closes #7001 from cafreeman/branch-1.4 and squashes the following commits: 8f81194 [cafreeman] Add missing license 31aedcf [cafreeman] Refactor test to call an external R script 2c22073 [cafreeman] Merge branch 'branch-1.4' of github.com:apache/spark into branch-1.4 0bea809 [cafreeman] Fixed relative path issue and added smaller JAR ee25e60 [cafreeman] Merge branch 'branch-1.4' of github.com:apache/spark into branch-1.4 9a5c362 [cafreeman] test for including JAR when launching sparkContext 9101223 [cafreeman] Merge branch 'branch-1.4' of github.com:apache/spark into branch-1.4 5a80844 [cafreeman] Fix style nits 7c6bd0c [cafreeman] [SPARK-8607] SparkR (cherry picked from commit 2579948bf5d89ac2d822ace605a6a4afce5258d6) Signed-off-by: Shivaram Venkataraman --- .../test_support/sparktestjar_2.10-1.0.jar | Bin 0 -> 2886 bytes R/pkg/inst/tests/jarTest.R | 32 +++++++++++++++ R/pkg/inst/tests/test_includeJAR.R | 37 ++++++++++++++++++ .../apache/spark/api/r/RBackendHandler.scala | 17 +++++++- 4 files changed, 85 insertions(+), 1 deletion(-) create mode 100644 R/pkg/inst/test_support/sparktestjar_2.10-1.0.jar create mode 100644 R/pkg/inst/tests/jarTest.R create mode 100644 R/pkg/inst/tests/test_includeJAR.R diff --git a/R/pkg/inst/test_support/sparktestjar_2.10-1.0.jar b/R/pkg/inst/test_support/sparktestjar_2.10-1.0.jar new file mode 100644 index 0000000000000000000000000000000000000000..1d5c2af631aa3ae88aa7836e8db598e59cbcf1b7 GIT binary patch literal 2886 zcmaJ@2T)UK7Y!kyL@*d49Sj{JRS@YAsv!nq=*S|~fYJj-g+=LYfzbORNV8FN{bdCk zg0zIbAfgZyq{;FFilHpKi8@1Yb?=)u^SycZJMX@^=brE2Fzg^WfQyR@ut^-V0I&oc z00Lmm?NG{SYYSB@${KB9ZfmE4wbuIiGP0M_cNecVtU;Sur6_lz zsaWb^v=SR+A;CLuy3$1vE+`}5lL)YnkQMN#MIA*nA&%;3C~FAI<>zOe_O2pl73Db< z*~X|~OI8rlVFtFrMKPnCLMpTwAOMHqFeX~AEe^t??EK`oE#4vGUh9FCIkePXoe5stL%}& zNVAhqLvP+N2Ki!4axnH{pCBmT>;l@Gm+pFq9mvWPc4#D^ejDn^&%HvUK^OB z)EgN^0iUSck}R0#%xqa6FDyH=u4lyK#mnchVHF8GuTYWvvwq8}Wg!P7M*Y|;8%rrT zSMkG(9`ZZmI9>ms=PmBAWd9<1I7oL}szYkqA=5BaS>~7Fg4Fu;I@PigC*Z->SEaFo z)mA(U+Et_0mnO!HE930f3)`UwV_k-Irmgy4i(gJx=yLphKC9Z|?D+e)*g9|J@Rsv7 zk~1Pi$~vua^rxVv3bbNK;S<5F#?d~J`xnxM^8*4`fzJ2NyI^8D1%q_oI-HdKqvn++ z@j;zGC;;QDtGL>2zA`gW*&5@)PLxO097`UtpP@;T8Uuo*gl5_07hrMQ`oFAcb9Zpv zU;eNo|IP=^hg}#?g_&HGTvq1g0WHlng}gi$D0%CHfbvHm+#?yqW(9U9_w{s}bsD60 zJbKlZTGQ2eSt2gWZfi(kU)^7K5xfcIeb*Fv%>>#q_2YJy&1isVF^WX@#n6hj6z{`$ z%jC^tPLH<4yHC$bm@gpo;104PgXP`zRac*-ITr)B!AzcE1SyllYyf~A@C(nrPYj{& z5kuw+Gpg{PnPDKR7ye#IBnz)F3M|+5yB;wmr()vg;(b}QHGS>#( z3vS3(@8`Xdz_ZHZ5;L)1`Z2Y^eK=$BWymyG{N|B)PX$>RK0=zub0-sStx+Pv>1NSF zCCw=kgsKW_ELQGn0T9Y&+I!bpHCO`y5xKugYBh#)_91|y|bv3&bbNn zhl?VRCe)^$J}q+=Qe+qHvz>9J*ohAqsNP^``th`sTtvmSm>alzorsQR37?+mp&3oj zY& z-k5K3sz5^m;*OY4n|5OT1|(XAJm)oLyn_?h$3hGPO(hQPO*jb4N2-&N?h;NAI z=|S}KvIr0K3iAsJ{7SfRa*uVZF+Ab#XH`{T@KaGilFF5MkgL^;S{WGyZiRC-RlWyO z=G$vROpZ`kR$=NeU#AyYCMvqt$Fy_Aie4LY#jBp$cw7PqDEPtTh@#xpf`=2-BF zNjKdZWD>bJ*>Rv?-)y70+L^rhn@RkT%g8=`TM8KUP}ueLr+ORi+?~H)>@&F$&>PUP zt`~RVo|c@#Qa{x=v06Gob2O>Tdzy1dQPO$o<5>ff*2{Mcw5%I^VYT-&l@prsYNVsX zI3RXuY@8*{uGUH>j~-8trp2Gz=N0(6E0(-8o*pi$#M6@|wdpI|RQ<=js%)_?XnEIF z5HtmslKq+Jh;CpZU$T;;eF? 
zxpCQ&39>;eWaswZk0LCAGCDUJ2Jb_)35SFHJ|10PT-J6CIOVOCTx+gwW=oG1vhuKp z`HXD|GMhCjwbXK&p$fcvRN>oZ9r~|keG!N7%#%Sd6ki7+&^etwJrwvqoSM$1N*KS2 zyVq^A?D&|c{p8Sqd+4LiG{mX-(qIxc^8cb6)3H!PI@>lVO62We8|R@1UGFJIYdx3m zYCPWArgH<7Usxz`l+icezS#d7@teFp-%!*v)^sWYi7;6&nhcHTdKhm|MDoF3kZT-3_HO6kWVS)v`toS%e?3*2|~%cZro` z)VXA?Azv-4#lxOOl5G8Ij-%eTgldL}ZV-}0)X%oM3#AHbR)oLKXTqvz+lZ{3N@(p3yvbfn5G-99m*{uhm!9X}zvQGyzicGsj=}Bjwy);9H?Ff==SdETklfOmO6)|W6!QWPkfel(kTtipW5d1 z@LrE>9H@+1qE^&56Lbj5@sZ7BeuDcHvXXB&dthawfgqWa`03h`<{Wtd!F$TDA5oZ- z8_l=4sUaviPan9nu-=YOd855*67tq9$@oN`%9_5>v?#yMX4~T`V}5wj(|1=tWTzjQ ztwaJc#Q{sA-k9desi&iQNycsy5F9t4QbsJloP&HNu~-XC-^XN1_zY>^YX(ysKQo05 z2nXz*AgmsSX{+|ek4zR0!v=%^e(ZO4QEu<@@4q%N{m*U;GL~O0(^ogNw`kS_k?Dta zW1F#L-O1vPn4f3;b5^lqo}IgKkRgBn0{JRztSHP`W1T|8E(Bv01>TGDJ(>I#jkQzE i$=!{^3>(Q>(;l=hbBx1)IhY$b8P_ + val clsContext = Class.forName(objId, true, Thread.currentThread().getContextClassLoader) + clsContext + } + } + def handleMethodCall( isStatic: Boolean, objId: String, @@ -98,7 +113,7 @@ private[r] class RBackendHandler(server: RBackend) var obj: Object = null try { val cls = if (isStatic) { - Class.forName(objId) + getStaticClass(objId) } else { JVMObjectTracker.get(objId) match { case None => throw new IllegalArgumentException("Object not found " + objId) From b5a6663da28198c905df27534cd123360a9bbef1 Mon Sep 17 00:00:00 2001 From: Rosstin Date: Sat, 27 Jun 2015 08:47:00 +0300 Subject: [PATCH 023/122] [SPARK-8639] [DOCS] Fixed Minor Typos in Documentation Ticket: [SPARK-8639](https://issues.apache.org/jira/browse/SPARK-8639) fixed minor typos in docs/README.md and docs/api.md Author: Rosstin Closes #7046 from Rosstin/SPARK-8639 and squashes the following commits: 6c18058 [Rosstin] fixed minor typos in docs/README.md and docs/api.md --- docs/README.md | 2 +- docs/api.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/README.md b/docs/README.md index 5852f972a051d..d7652e921f7df 100644 --- a/docs/README.md +++ b/docs/README.md @@ -28,7 +28,7 @@ in some cases: $ sudo gem install jekyll $ sudo gem install jekyll-redirect-from -Execute `jekyll` from the `docs/` directory. Compiling the site with Jekyll will create a directory +Execute `jekyll build` from the `docs/` directory to compile the site. Compiling the site with Jekyll will create a directory called `_site` containing index.html as well as the rest of the compiled files. You can modify the default Jekyll build as follows: diff --git a/docs/api.md b/docs/api.md index 45df77ac05f78..ae7d51c2aefbf 100644 --- a/docs/api.md +++ b/docs/api.md @@ -3,7 +3,7 @@ layout: global title: Spark API Documentation --- -Here you can API docs for Spark and its submodules. +Here you can read API docs for Spark and its submodules. - [Spark Scala API (Scaladoc)](api/scala/index.html) - [Spark Java API (Javadoc)](api/java/index.html) From d48e78934a346f023bd5cf44a34320f4d5a88e12 Mon Sep 17 00:00:00 2001 From: Neelesh Srinivas Salian Date: Sat, 27 Jun 2015 09:07:10 +0300 Subject: [PATCH 024/122] [SPARK-3629] [YARN] [DOCS]: Improvement of the "Running Spark on YARN" document As per the description in the JIRA, I moved the contents of the page and added a few additional content. 
Author: Neelesh Srinivas Salian Closes #6924 from nssalian/SPARK-3629 and squashes the following commits: 944b7a0 [Neelesh Srinivas Salian] Changed the lines about deploy-mode and added backticks to all parameters 40dbc0b [Neelesh Srinivas Salian] Changed dfs to HDFS, deploy-mode in backticks and updated the master yarn line 9cbc072 [Neelesh Srinivas Salian] Updated a few lines in the Launching Spark on YARN Section 8e8db7f [Neelesh Srinivas Salian] Removed the changes in this commit to help clearly distinguish movement from update 151c298 [Neelesh Srinivas Salian] SPARK-3629: Improvement of the Spark on YARN document --- docs/running-on-yarn.md | 164 ++++++++++++++++++++-------------------- 1 file changed, 82 insertions(+), 82 deletions(-) diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index 3f8a093bbe957..de22ab557cacf 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -7,6 +7,51 @@ Support for running on [YARN (Hadoop NextGen)](http://hadoop.apache.org/docs/stable/hadoop-yarn/hadoop-yarn-site/YARN.html) was added to Spark in version 0.6.0, and improved in subsequent releases. +# Launching Spark on YARN + +Ensure that `HADOOP_CONF_DIR` or `YARN_CONF_DIR` points to the directory which contains the (client side) configuration files for the Hadoop cluster. +These configs are used to write to HDFS and connect to the YARN ResourceManager. The +configuration contained in this directory will be distributed to the YARN cluster so that all +containers used by the application use the same configuration. If the configuration references +Java system properties or environment variables not managed by YARN, they should also be set in the +Spark application's configuration (driver, executors, and the AM when running in client mode). + +There are two deploy modes that can be used to launch Spark applications on YARN. In `yarn-cluster` mode, the Spark driver runs inside an application master process which is managed by YARN on the cluster, and the client can go away after initiating the application. In `yarn-client` mode, the driver runs in the client process, and the application master is only used for requesting resources from YARN. + +Unlike in Spark standalone and Mesos mode, in which the master's address is specified in the `--master` parameter, in YARN mode the ResourceManager's address is picked up from the Hadoop configuration. Thus, the `--master` parameter is `yarn-client` or `yarn-cluster`. +To launch a Spark application in `yarn-cluster` mode: + + `$ ./bin/spark-submit --class path.to.your.Class --master yarn-cluster [options] [app options]` + +For example: + + $ ./bin/spark-submit --class org.apache.spark.examples.SparkPi \ + --master yarn-cluster \ + --num-executors 3 \ + --driver-memory 4g \ + --executor-memory 2g \ + --executor-cores 1 \ + --queue thequeue \ + lib/spark-examples*.jar \ + 10 + +The above starts a YARN client program which starts the default Application Master. Then SparkPi will be run as a child thread of Application Master. The client will periodically poll the Application Master for status updates and display them in the console. The client will exit once your application has finished running. Refer to the "Debugging your Application" section below for how to see driver and executor logs. + +To launch a Spark application in `yarn-client` mode, do the same, but replace `yarn-cluster` with `yarn-client`. 
To run spark-shell: + + $ ./bin/spark-shell --master yarn-client + +## Adding Other JARs + +In `yarn-cluster` mode, the driver runs on a different machine than the client, so `SparkContext.addJar` won't work out of the box with files that are local to the client. To make files on the client available to `SparkContext.addJar`, include them with the `--jars` option in the launch command. + + $ ./bin/spark-submit --class my.main.Class \ + --master yarn-cluster \ + --jars my-other-jar.jar,my-other-other-jar.jar + my-main-jar.jar + app_arg1 app_arg2 + + # Preparations Running Spark-on-YARN requires a binary distribution of Spark which is built with YARN support. @@ -17,6 +62,38 @@ To build Spark yourself, refer to [Building Spark](building-spark.html). Most of the configs are the same for Spark on YARN as for other deployment modes. See the [configuration page](configuration.html) for more information on those. These are configs that are specific to Spark on YARN. +# Debugging your Application + +In YARN terminology, executors and application masters run inside "containers". YARN has two modes for handling container logs after an application has completed. If log aggregation is turned on (with the `yarn.log-aggregation-enable` config), container logs are copied to HDFS and deleted on the local machine. These logs can be viewed from anywhere on the cluster with the "yarn logs" command. + + yarn logs -applicationId + +will print out the contents of all log files from all containers from the given application. You can also view the container log files directly in HDFS using the HDFS shell or API. The directory where they are located can be found by looking at your YARN configs (`yarn.nodemanager.remote-app-log-dir` and `yarn.nodemanager.remote-app-log-dir-suffix`). + +When log aggregation isn't turned on, logs are retained locally on each machine under `YARN_APP_LOGS_DIR`, which is usually configured to `/tmp/logs` or `$HADOOP_HOME/logs/userlogs` depending on the Hadoop version and installation. Viewing logs for a container requires going to the host that contains them and looking in this directory. Subdirectories organize log files by application ID and container ID. + +To review per-container launch environment, increase `yarn.nodemanager.delete.debug-delay-sec` to a +large value (e.g. 36000), and then access the application cache through `yarn.nodemanager.local-dirs` +on the nodes on which containers are launched. This directory contains the launch script, JARs, and +all environment variables used for launching each container. This process is useful for debugging +classpath problems in particular. (Note that enabling this requires admin privileges on cluster +settings and a restart of all node managers. Thus, this is not applicable to hosted clusters). + +To use a custom log4j configuration for the application master or executors, there are two options: + +- upload a custom `log4j.properties` using `spark-submit`, by adding it to the `--files` list of files + to be uploaded with the application. +- add `-Dlog4j.configuration=` to `spark.driver.extraJavaOptions` + (for the driver) or `spark.executor.extraJavaOptions` (for executors). Note that if using a file, + the `file:` protocol should be explicitly provided, and the file needs to exist locally on all + the nodes. + +Note that for the first option, both executors and the application master will share the same +log4j configuration, which may cause issues when they run on the same node (e.g. trying to write +to the same log file). 
+ +If you need a reference to the proper location to put log files in the YARN so that YARN can properly display and aggregate them, use `spark.yarn.app.container.log.dir` in your log4j.properties. For example, `log4j.appender.file_appender.File=${spark.yarn.app.container.log.dir}/spark.log`. For streaming application, configuring `RollingFileAppender` and setting file location to YARN's log directory will avoid disk overflow caused by large log file, and logs can be accessed using YARN's log utility. + #### Spark Properties @@ -50,8 +127,8 @@ Most of the configs are the same for Spark on YARN as for other deployment modes @@ -189,8 +266,8 @@ Most of the configs are the same for Spark on YARN as for other deployment modes @@ -206,7 +283,7 @@ Most of the configs are the same for Spark on YARN as for other deployment modes @@ -286,83 +363,6 @@ Most of the configs are the same for Spark on YARN as for other deployment modes
spark.yarn.am.waitTime 100s - In yarn-cluster mode, time for the application master to wait for the - SparkContext to be initialized. In yarn-client mode, time for the application master to wait + In `yarn-cluster` mode, time for the application master to wait for the + SparkContext to be initialized. In `yarn-client` mode, time for the application master to wait for the driver to connect to it.
Add the environment variable specified by EnvironmentVariableName to the Application Master process launched on YARN. The user can specify multiple of - these and to set multiple environment variables. In yarn-cluster mode this controls - the environment of the SPARK driver and in yarn-client mode it only controls + these and to set multiple environment variables. In `yarn-cluster` mode this controls + the environment of the SPARK driver and in `yarn-client` mode it only controls the environment of the executor launcher.
(none) A string of extra JVM options to pass to the YARN Application Master in client mode. - In cluster mode, use spark.driver.extraJavaOptions instead. + In cluster mode, use `spark.driver.extraJavaOptions` instead.
-# Launching Spark on YARN - -Ensure that `HADOOP_CONF_DIR` or `YARN_CONF_DIR` points to the directory which contains the (client side) configuration files for the Hadoop cluster. -These configs are used to write to the dfs and connect to the YARN ResourceManager. The -configuration contained in this directory will be distributed to the YARN cluster so that all -containers used by the application use the same configuration. If the configuration references -Java system properties or environment variables not managed by YARN, they should also be set in the -Spark application's configuration (driver, executors, and the AM when running in client mode). - -There are two deploy modes that can be used to launch Spark applications on YARN. In yarn-cluster mode, the Spark driver runs inside an application master process which is managed by YARN on the cluster, and the client can go away after initiating the application. In yarn-client mode, the driver runs in the client process, and the application master is only used for requesting resources from YARN. - -Unlike in Spark standalone and Mesos mode, in which the master's address is specified in the "master" parameter, in YARN mode the ResourceManager's address is picked up from the Hadoop configuration. Thus, the master parameter is simply "yarn-client" or "yarn-cluster". - -To launch a Spark application in yarn-cluster mode: - - ./bin/spark-submit --class path.to.your.Class --master yarn-cluster [options] [app options] - -For example: - - $ ./bin/spark-submit --class org.apache.spark.examples.SparkPi \ - --master yarn-cluster \ - --num-executors 3 \ - --driver-memory 4g \ - --executor-memory 2g \ - --executor-cores 1 \ - --queue thequeue \ - lib/spark-examples*.jar \ - 10 - -The above starts a YARN client program which starts the default Application Master. Then SparkPi will be run as a child thread of Application Master. The client will periodically poll the Application Master for status updates and display them in the console. The client will exit once your application has finished running. Refer to the "Debugging your Application" section below for how to see driver and executor logs. - -To launch a Spark application in yarn-client mode, do the same, but replace "yarn-cluster" with "yarn-client". To run spark-shell: - - $ ./bin/spark-shell --master yarn-client - -## Adding Other JARs - -In yarn-cluster mode, the driver runs on a different machine than the client, so `SparkContext.addJar` won't work out of the box with files that are local to the client. To make files on the client available to `SparkContext.addJar`, include them with the `--jars` option in the launch command. - - $ ./bin/spark-submit --class my.main.Class \ - --master yarn-cluster \ - --jars my-other-jar.jar,my-other-other-jar.jar - my-main-jar.jar - app_arg1 app_arg2 - -# Debugging your Application - -In YARN terminology, executors and application masters run inside "containers". YARN has two modes for handling container logs after an application has completed. If log aggregation is turned on (with the `yarn.log-aggregation-enable` config), container logs are copied to HDFS and deleted on the local machine. These logs can be viewed from anywhere on the cluster with the "yarn logs" command. - - yarn logs -applicationId - -will print out the contents of all log files from all containers from the given application. You can also view the container log files directly in HDFS using the HDFS shell or API. 
The directory where they are located can be found by looking at your YARN configs (`yarn.nodemanager.remote-app-log-dir` and `yarn.nodemanager.remote-app-log-dir-suffix`). - -When log aggregation isn't turned on, logs are retained locally on each machine under `YARN_APP_LOGS_DIR`, which is usually configured to `/tmp/logs` or `$HADOOP_HOME/logs/userlogs` depending on the Hadoop version and installation. Viewing logs for a container requires going to the host that contains them and looking in this directory. Subdirectories organize log files by application ID and container ID. - -To review per-container launch environment, increase `yarn.nodemanager.delete.debug-delay-sec` to a -large value (e.g. 36000), and then access the application cache through `yarn.nodemanager.local-dirs` -on the nodes on which containers are launched. This directory contains the launch script, JARs, and -all environment variables used for launching each container. This process is useful for debugging -classpath problems in particular. (Note that enabling this requires admin privileges on cluster -settings and a restart of all node managers. Thus, this is not applicable to hosted clusters). - -To use a custom log4j configuration for the application master or executors, there are two options: - -- upload a custom `log4j.properties` using `spark-submit`, by adding it to the `--files` list of files - to be uploaded with the application. -- add `-Dlog4j.configuration=` to `spark.driver.extraJavaOptions` - (for the driver) or `spark.executor.extraJavaOptions` (for executors). Note that if using a file, - the `file:` protocol should be explicitly provided, and the file needs to exist locally on all - the nodes. - -Note that for the first option, both executors and the application master will share the same -log4j configuration, which may cause issues when they run on the same node (e.g. trying to write -to the same log file). - -If you need a reference to the proper location to put log files in the YARN so that YARN can properly display and aggregate them, use `spark.yarn.app.container.log.dir` in your log4j.properties. For example, `log4j.appender.file_appender.File=${spark.yarn.app.container.log.dir}/spark.log`. For streaming application, configuring `RollingFileAppender` and setting file location to YARN's log directory will avoid disk overflow caused by large log file, and logs can be accessed using YARN's log utility. - # Important notes - Whether core requests are honored in scheduling decisions depends on which scheduler is in use and how it is configured. From 4153776fd840ae075e6bb608f054091b6d3ec0c4 Mon Sep 17 00:00:00 2001 From: Sandy Ryza Date: Sat, 27 Jun 2015 14:33:31 -0700 Subject: [PATCH 025/122] [SPARK-8623] Hadoop RDDs fail to properly serialize configuration Author: Sandy Ryza Closes #7050 from sryza/sandy-spark-8623 and squashes the following commits: 58a8079 [Sandy Ryza] SPARK-8623. 
Hadoop RDDs fail to properly serialize configuration --- .../scala/org/apache/spark/serializer/KryoSerializer.scala | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index cd8a82347a1e9..ed35cffe968f8 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -36,7 +36,7 @@ import org.apache.spark.network.nio.{GetBlock, GotBlock, PutBlock} import org.apache.spark.network.util.ByteUnit import org.apache.spark.scheduler.{CompressedMapStatus, HighlyCompressedMapStatus} import org.apache.spark.storage._ -import org.apache.spark.util.BoundedPriorityQueue +import org.apache.spark.util.{BoundedPriorityQueue, SerializableConfiguration, SerializableJobConf} import org.apache.spark.util.collection.CompactBuffer /** @@ -94,8 +94,10 @@ class KryoSerializer(conf: SparkConf) // For results returned by asJavaIterable. See JavaIterableWrapperSerializer. kryo.register(JavaIterableWrapperSerializer.wrapperClass, new JavaIterableWrapperSerializer) - // Allow sending SerializableWritable + // Allow sending classes with custom Java serializers kryo.register(classOf[SerializableWritable[_]], new KryoJavaSerializer()) + kryo.register(classOf[SerializableConfiguration], new KryoJavaSerializer()) + kryo.register(classOf[SerializableJobConf], new KryoJavaSerializer()) kryo.register(classOf[HttpBroadcast[_]], new KryoJavaSerializer()) kryo.register(classOf[PythonBroadcast], new KryoJavaSerializer()) From 0b5abbf5f96a5f6bfd15a65e8788cf3fa96fe54c Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sat, 27 Jun 2015 14:40:45 -0700 Subject: [PATCH 026/122] [SPARK-8606] Prevent exceptions in RDD.getPreferredLocations() from crashing DAGScheduler If `RDD.getPreferredLocations()` throws an exception it may crash the DAGScheduler and SparkContext. This patch addresses this by adding a try-catch block. Author: Josh Rosen Closes #7023 from JoshRosen/SPARK-8606 and squashes the following commits: 770b169 [Josh Rosen] Fix getPreferredLocations() DAGScheduler crash with try block. 
44a9b55 [Josh Rosen] Add test of a buggy getPartitions() method 19aa9f7 [Josh Rosen] Add (failing) regression test for getPreferredLocations() DAGScheduler crash --- .../apache/spark/scheduler/DAGScheduler.scala | 37 +++++++++++-------- .../spark/scheduler/DAGSchedulerSuite.scala | 31 ++++++++++++++++ 2 files changed, 53 insertions(+), 15 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index b00a5fee09bf2..a7cf0c23d9613 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -907,22 +907,29 @@ class DAGScheduler( return } - val tasks: Seq[Task[_]] = stage match { - case stage: ShuffleMapStage => - partitionsToCompute.map { id => - val locs = getPreferredLocs(stage.rdd, id) - val part = stage.rdd.partitions(id) - new ShuffleMapTask(stage.id, taskBinary, part, locs) - } + val tasks: Seq[Task[_]] = try { + stage match { + case stage: ShuffleMapStage => + partitionsToCompute.map { id => + val locs = getPreferredLocs(stage.rdd, id) + val part = stage.rdd.partitions(id) + new ShuffleMapTask(stage.id, taskBinary, part, locs) + } - case stage: ResultStage => - val job = stage.resultOfJob.get - partitionsToCompute.map { id => - val p: Int = job.partitions(id) - val part = stage.rdd.partitions(p) - val locs = getPreferredLocs(stage.rdd, p) - new ResultTask(stage.id, taskBinary, part, locs, id) - } + case stage: ResultStage => + val job = stage.resultOfJob.get + partitionsToCompute.map { id => + val p: Int = job.partitions(id) + val part = stage.rdd.partitions(p) + val locs = getPreferredLocs(stage.rdd, p) + new ResultTask(stage.id, taskBinary, part, locs, id) + } + } + } catch { + case NonFatal(e) => + abortStage(stage, s"Task creation failed: $e\n${e.getStackTraceString}") + runningStages -= stage + return } if (tasks.size > 0) { diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 833b600746e90..6bc45f249f975 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -784,6 +784,37 @@ class DAGSchedulerSuite assert(sc.parallelize(1 to 10, 2).first() === 1) } + test("getPartitions exceptions should not crash DAGScheduler and SparkContext (SPARK-8606)") { + val e1 = intercept[DAGSchedulerSuiteDummyException] { + val rdd = new MyRDD(sc, 2, Nil) { + override def getPartitions: Array[Partition] = { + throw new DAGSchedulerSuiteDummyException + } + } + rdd.reduceByKey(_ + _, 1).count() + } + + // Make sure we can still run local commands as well as cluster commands. + assert(sc.parallelize(1 to 10, 2).count() === 10) + assert(sc.parallelize(1 to 10, 2).first() === 1) + } + + test("getPreferredLocations errors should not crash DAGScheduler and SparkContext (SPARK-8606)") { + val e1 = intercept[SparkException] { + val rdd = new MyRDD(sc, 2, Nil) { + override def getPreferredLocations(split: Partition): Seq[String] = { + throw new DAGSchedulerSuiteDummyException + } + } + rdd.count() + } + assert(e1.getMessage.contains(classOf[DAGSchedulerSuiteDummyException].getName)) + + // Make sure we can still run local commands as well as cluster commands. 
+ assert(sc.parallelize(1 to 10, 2).count() === 10) + assert(sc.parallelize(1 to 10, 2).first() === 1) + } + test("accumulator not calculated for resubmitted result stage") { // just for register val accum = new Accumulator[Int](0, AccumulatorParam.IntAccumulatorParam) From 40648c56cdaa52058a4771082f8f44a2d8e5a1ec Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sat, 27 Jun 2015 20:24:34 -0700 Subject: [PATCH 027/122] [SPARK-8583] [SPARK-5482] [BUILD] Refactor python/run-tests to integrate with dev/run-tests module system This patch refactors the `python/run-tests` script: - It's now written in Python instead of Bash. - The descriptions of the tests to run are now stored in `dev/run-tests`'s modules. This allows the pull request builder to skip Python tests suites that were not affected by the pull request's changes. For example, we can now skip the PySpark Streaming test cases when only SQL files are changed. - `python/run-tests` now supports command-line flags to make it easier to run individual test suites (this addresses SPARK-5482): ``` Usage: run-tests [options] Options: -h, --help show this help message and exit --python-executables=PYTHON_EXECUTABLES A comma-separated list of Python executables to test against (default: python2.6,python3.4,pypy) --modules=MODULES A comma-separated list of Python modules to test (default: pyspark-core,pyspark-ml,pyspark-mllib ,pyspark-sql,pyspark-streaming) ``` - `dev/run-tests` has been split into multiple files: the module definitions and test utility functions are now stored inside of a `dev/sparktestsupport` Python module, allowing them to be re-used from the Python test runner script. Author: Josh Rosen Closes #6967 from JoshRosen/run-tests-python-modules and squashes the following commits: f578d6d [Josh Rosen] Fix print for Python 2.x 8233d61 [Josh Rosen] Add python/run-tests.py to Python lint checks 34c98d2 [Josh Rosen] Fix universal_newlines for Python 3 8f65ed0 [Josh Rosen] Fix handling of module in python/run-tests 37aff00 [Josh Rosen] Python 3 fix 27a389f [Josh Rosen] Skip MLLib tests for PyPy c364ccf [Josh Rosen] Use which() to convert PYSPARK_PYTHON to an absolute path before shelling out to run tests 568a3fd [Josh Rosen] Fix hashbang 3b852ae [Josh Rosen] Fall back to PYSPARK_PYTHON when sys.executable is None (fixes a test) f53db55 [Josh Rosen] Remove python2 flag, since the test runner script also works fine under Python 3 9c80469 [Josh Rosen] Fix passing of PYSPARK_PYTHON d33e525 [Josh Rosen] Merge remote-tracking branch 'origin/master' into run-tests-python-modules 4f8902c [Josh Rosen] Python lint fixes. 8f3244c [Josh Rosen] Use universal_newlines to fix dev/run-tests doctest failures on Python 3. f542ac5 [Josh Rosen] Fix lint check for Python 3 fff4d09 [Josh Rosen] Add dev/sparktestsupport to pep8 checks 2efd594 [Josh Rosen] Update dev/run-tests to use new Python test runner flags b2ab027 [Josh Rosen] Add command-line options for running individual suites in python/run-tests caeb040 [Josh Rosen] Fixes to PySpark test module definitions d6a77d3 [Josh Rosen] Fix the tests of dev/run-tests def2d8a [Josh Rosen] Two minor fixes aec0b8f [Josh Rosen] Actually get the Kafka stuff to run properly 04015b9 [Josh Rosen] First attempt at getting PySpark Kafka test to work in new runner script 4c97136 [Josh Rosen] PYTHONPATH fixes dcc9c09 [Josh Rosen] Fix time division 32660fc [Josh Rosen] Initial cut at Python test runner refactoring 311c6a9 [Josh Rosen] Move shell utility functions to own module. 
1bdeb87 [Josh Rosen] Move module definitions to separate file. --- dev/lint-python | 3 +- dev/run-tests.py | 435 ++++------------------------- dev/sparktestsupport/__init__.py | 21 ++ dev/sparktestsupport/modules.py | 385 +++++++++++++++++++++++++ dev/sparktestsupport/shellutils.py | 81 ++++++ python/pyspark/streaming/tests.py | 16 ++ python/pyspark/tests.py | 3 +- python/run-tests | 164 +---------- python/run-tests.py | 132 +++++++++ 9 files changed, 700 insertions(+), 540 deletions(-) create mode 100644 dev/sparktestsupport/__init__.py create mode 100644 dev/sparktestsupport/modules.py create mode 100644 dev/sparktestsupport/shellutils.py create mode 100755 python/run-tests.py diff --git a/dev/lint-python b/dev/lint-python index f50d149dc4d44..0c3586462cb37 100755 --- a/dev/lint-python +++ b/dev/lint-python @@ -19,7 +19,8 @@ SCRIPT_DIR="$( cd "$( dirname "$0" )" && pwd )" SPARK_ROOT_DIR="$(dirname "$SCRIPT_DIR")" -PATHS_TO_CHECK="./python/pyspark/ ./ec2/spark_ec2.py ./examples/src/main/python/" +PATHS_TO_CHECK="./python/pyspark/ ./ec2/spark_ec2.py ./examples/src/main/python/ ./dev/sparktestsupport" +PATHS_TO_CHECK="$PATHS_TO_CHECK ./dev/run-tests.py ./python/run-tests.py" PYTHON_LINT_REPORT_PATH="$SPARK_ROOT_DIR/dev/python-lint-report.txt" cd "$SPARK_ROOT_DIR" diff --git a/dev/run-tests.py b/dev/run-tests.py index e7c09b0f40cdc..c51b0d3010a0f 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -17,297 +17,23 @@ # limitations under the License. # +from __future__ import print_function import itertools import os import re import sys -import shutil import subprocess from collections import namedtuple -SPARK_HOME = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..") -USER_HOME = os.environ.get("HOME") - +from sparktestsupport import SPARK_HOME, USER_HOME +from sparktestsupport.shellutils import exit_from_command_with_retcode, run_cmd, rm_r, which +import sparktestsupport.modules as modules # ------------------------------------------------------------------------------------------------- -# Test module definitions and functions for traversing module dependency graph +# Functions for traversing module dependency graph # ------------------------------------------------------------------------------------------------- -all_modules = [] - - -class Module(object): - """ - A module is the basic abstraction in our test runner script. Each module consists of a set of - source files, a set of test commands, and a set of dependencies on other modules. We use modules - to define a dependency graph that lets determine which tests to run based on which files have - changed. - """ - - def __init__(self, name, dependencies, source_file_regexes, build_profile_flags=(), - sbt_test_goals=(), should_run_python_tests=False, should_run_r_tests=False): - """ - Define a new module. - - :param name: A short module name, for display in logging and error messages. - :param dependencies: A set of dependencies for this module. This should only include direct - dependencies; transitive dependencies are resolved automatically. - :param source_file_regexes: a set of regexes that match source files belonging to this - module. These regexes are applied by attempting to match at the beginning of the - filename strings. - :param build_profile_flags: A set of profile flags that should be passed to Maven or SBT in - order to build and test this module (e.g. '-PprofileName'). - :param sbt_test_goals: A set of SBT test goals for testing this module. 
- :param should_run_python_tests: If true, changes in this module will trigger Python tests. - For now, this has the effect of causing _all_ Python tests to be run, although in the - future this should be changed to run only a subset of the Python tests that depend - on this module. - :param should_run_r_tests: If true, changes in this module will trigger all R tests. - """ - self.name = name - self.dependencies = dependencies - self.source_file_prefixes = source_file_regexes - self.sbt_test_goals = sbt_test_goals - self.build_profile_flags = build_profile_flags - self.should_run_python_tests = should_run_python_tests - self.should_run_r_tests = should_run_r_tests - - self.dependent_modules = set() - for dep in dependencies: - dep.dependent_modules.add(self) - all_modules.append(self) - - def contains_file(self, filename): - return any(re.match(p, filename) for p in self.source_file_prefixes) - - -sql = Module( - name="sql", - dependencies=[], - source_file_regexes=[ - "sql/(?!hive-thriftserver)", - "bin/spark-sql", - ], - build_profile_flags=[ - "-Phive", - ], - sbt_test_goals=[ - "catalyst/test", - "sql/test", - "hive/test", - ] -) - - -hive_thriftserver = Module( - name="hive-thriftserver", - dependencies=[sql], - source_file_regexes=[ - "sql/hive-thriftserver", - "sbin/start-thriftserver.sh", - ], - build_profile_flags=[ - "-Phive-thriftserver", - ], - sbt_test_goals=[ - "hive-thriftserver/test", - ] -) - - -graphx = Module( - name="graphx", - dependencies=[], - source_file_regexes=[ - "graphx/", - ], - sbt_test_goals=[ - "graphx/test" - ] -) - - -streaming = Module( - name="streaming", - dependencies=[], - source_file_regexes=[ - "streaming", - ], - sbt_test_goals=[ - "streaming/test", - ] -) - - -streaming_kinesis_asl = Module( - name="kinesis-asl", - dependencies=[streaming], - source_file_regexes=[ - "extras/kinesis-asl/", - ], - build_profile_flags=[ - "-Pkinesis-asl", - ], - sbt_test_goals=[ - "kinesis-asl/test", - ] -) - - -streaming_zeromq = Module( - name="streaming-zeromq", - dependencies=[streaming], - source_file_regexes=[ - "external/zeromq", - ], - sbt_test_goals=[ - "streaming-zeromq/test", - ] -) - - -streaming_twitter = Module( - name="streaming-twitter", - dependencies=[streaming], - source_file_regexes=[ - "external/twitter", - ], - sbt_test_goals=[ - "streaming-twitter/test", - ] -) - - -streaming_mqtt = Module( - name="streaming-mqtt", - dependencies=[streaming], - source_file_regexes=[ - "external/mqtt", - ], - sbt_test_goals=[ - "streaming-mqtt/test", - ] -) - - -streaming_kafka = Module( - name="streaming-kafka", - dependencies=[streaming], - source_file_regexes=[ - "external/kafka", - "external/kafka-assembly", - ], - sbt_test_goals=[ - "streaming-kafka/test", - ] -) - - -streaming_flume_sink = Module( - name="streaming-flume-sink", - dependencies=[streaming], - source_file_regexes=[ - "external/flume-sink", - ], - sbt_test_goals=[ - "streaming-flume-sink/test", - ] -) - - -streaming_flume = Module( - name="streaming_flume", - dependencies=[streaming], - source_file_regexes=[ - "external/flume", - ], - sbt_test_goals=[ - "streaming-flume/test", - ] -) - - -mllib = Module( - name="mllib", - dependencies=[streaming, sql], - source_file_regexes=[ - "data/mllib/", - "mllib/", - ], - sbt_test_goals=[ - "mllib/test", - ] -) - - -examples = Module( - name="examples", - dependencies=[graphx, mllib, streaming, sql], - source_file_regexes=[ - "examples/", - ], - sbt_test_goals=[ - "examples/test", - ] -) - - -pyspark = Module( - name="pyspark", - dependencies=[mllib, 
streaming, streaming_kafka, sql], - source_file_regexes=[ - "python/" - ], - should_run_python_tests=True -) - - -sparkr = Module( - name="sparkr", - dependencies=[sql, mllib], - source_file_regexes=[ - "R/", - ], - should_run_r_tests=True -) - - -docs = Module( - name="docs", - dependencies=[], - source_file_regexes=[ - "docs/", - ] -) - - -ec2 = Module( - name="ec2", - dependencies=[], - source_file_regexes=[ - "ec2/", - ] -) - - -# The root module is a dummy module which is used to run all of the tests. -# No other modules should directly depend on this module. -root = Module( - name="root", - dependencies=[], - source_file_regexes=[], - # In order to run all of the tests, enable every test profile: - build_profile_flags= - list(set(itertools.chain.from_iterable(m.build_profile_flags for m in all_modules))), - sbt_test_goals=[ - "test", - ], - should_run_python_tests=True, - should_run_r_tests=True -) - - def determine_modules_for_files(filenames): """ Given a list of filenames, return the set of modules that contain those files. @@ -315,19 +41,19 @@ def determine_modules_for_files(filenames): file to belong to the 'root' module. >>> sorted(x.name for x in determine_modules_for_files(["python/pyspark/a.py", "sql/test/foo"])) - ['pyspark', 'sql'] + ['pyspark-core', 'sql'] >>> [x.name for x in determine_modules_for_files(["file_not_matched_by_any_subproject"])] ['root'] """ changed_modules = set() for filename in filenames: matched_at_least_one_module = False - for module in all_modules: + for module in modules.all_modules: if module.contains_file(filename): changed_modules.add(module) matched_at_least_one_module = True if not matched_at_least_one_module: - changed_modules.add(root) + changed_modules.add(modules.root) return changed_modules @@ -352,7 +78,8 @@ def identify_changed_files_from_git_commits(patch_sha, target_branch=None, targe run_cmd(['git', 'fetch', 'origin', str(target_branch+':'+target_branch)]) else: diff_target = target_ref - raw_output = subprocess.check_output(['git', 'diff', '--name-only', patch_sha, diff_target]) + raw_output = subprocess.check_output(['git', 'diff', '--name-only', patch_sha, diff_target], + universal_newlines=True) # Remove any empty strings return [f for f in raw_output.split('\n') if f] @@ -362,18 +89,20 @@ def determine_modules_to_test(changed_modules): Given a set of modules that have changed, compute the transitive closure of those modules' dependent modules in order to determine the set of modules that should be tested. - >>> sorted(x.name for x in determine_modules_to_test([root])) + >>> sorted(x.name for x in determine_modules_to_test([modules.root])) ['root'] - >>> sorted(x.name for x in determine_modules_to_test([graphx])) + >>> sorted(x.name for x in determine_modules_to_test([modules.graphx])) ['examples', 'graphx'] - >>> sorted(x.name for x in determine_modules_to_test([sql])) - ['examples', 'hive-thriftserver', 'mllib', 'pyspark', 'sparkr', 'sql'] + >>> x = sorted(x.name for x in determine_modules_to_test([modules.sql])) + >>> x # doctest: +NORMALIZE_WHITESPACE + ['examples', 'hive-thriftserver', 'mllib', 'pyspark-core', 'pyspark-ml', \ + 'pyspark-mllib', 'pyspark-sql', 'pyspark-streaming', 'sparkr', 'sql'] """ # If we're going to have to run all of the tests, then we can just short-circuit # and return 'root'. No module depends on root, so if it appears then it will be # in changed_modules. 
- if root in changed_modules: - return [root] + if modules.root in changed_modules: + return [modules.root] modules_to_test = set() for module in changed_modules: modules_to_test = modules_to_test.union(determine_modules_to_test(module.dependent_modules)) @@ -398,60 +127,6 @@ def get_error_codes(err_code_file): ERROR_CODES = get_error_codes(os.path.join(SPARK_HOME, "dev/run-tests-codes.sh")) -def exit_from_command_with_retcode(cmd, retcode): - print "[error] running", ' '.join(cmd), "; received return code", retcode - sys.exit(int(os.environ.get("CURRENT_BLOCK", 255))) - - -def rm_r(path): - """Given an arbitrary path properly remove it with the correct python - construct if it exists - - from: http://stackoverflow.com/a/9559881""" - - if os.path.isdir(path): - shutil.rmtree(path) - elif os.path.exists(path): - os.remove(path) - - -def run_cmd(cmd): - """Given a command as a list of arguments will attempt to execute the - command from the determined SPARK_HOME directory and, on failure, print - an error message""" - - if not isinstance(cmd, list): - cmd = cmd.split() - try: - subprocess.check_call(cmd) - except subprocess.CalledProcessError as e: - exit_from_command_with_retcode(e.cmd, e.returncode) - - -def is_exe(path): - """Check if a given path is an executable file - - from: http://stackoverflow.com/a/377028""" - - return os.path.isfile(path) and os.access(path, os.X_OK) - - -def which(program): - """Find and return the given program by its absolute path or 'None' - - from: http://stackoverflow.com/a/377028""" - - fpath = os.path.split(program)[0] - - if fpath: - if is_exe(program): - return program - else: - for path in os.environ.get("PATH").split(os.pathsep): - path = path.strip('"') - exe_file = os.path.join(path, program) - if is_exe(exe_file): - return exe_file - return None - - def determine_java_executable(): """Will return the path of the java executable that will be used by Spark's tests or `None`""" @@ -476,7 +151,8 @@ def determine_java_version(java_exe): with accessors '.major', '.minor', '.patch', '.update'""" raw_output = subprocess.check_output([java_exe, "-version"], - stderr=subprocess.STDOUT) + stderr=subprocess.STDOUT, + universal_newlines=True) raw_output_lines = raw_output.split('\n') @@ -504,10 +180,10 @@ def set_title_and_block(title, err_block): os.environ["CURRENT_BLOCK"] = ERROR_CODES[err_block] line_str = '=' * 72 - print - print line_str - print title - print line_str + print('') + print(line_str) + print(title) + print(line_str) def run_apache_rat_checks(): @@ -534,8 +210,8 @@ def build_spark_documentation(): jekyll_bin = which("jekyll") if not jekyll_bin: - print "[error] Cannot find a version of `jekyll` on the system; please", - print "install one and retry to build documentation." + print("[error] Cannot find a version of `jekyll` on the system; please" + " install one and retry to build documentation.") sys.exit(int(os.environ.get("CURRENT_BLOCK", 255))) else: run_cmd([jekyll_bin, "build"]) @@ -571,7 +247,7 @@ def exec_sbt(sbt_args=()): echo_proc.wait() for line in iter(sbt_proc.stdout.readline, ''): if not sbt_output_filter.match(line): - print line, + print(line, end='') retcode = sbt_proc.wait() if retcode > 0: @@ -594,33 +270,33 @@ def get_hadoop_profiles(hadoop_version): if hadoop_version in sbt_maven_hadoop_profiles: return sbt_maven_hadoop_profiles[hadoop_version] else: - print "[error] Could not find", hadoop_version, "in the list. 
Valid options", - print "are", sbt_maven_hadoop_profiles.keys() + print("[error] Could not find", hadoop_version, "in the list. Valid options" + " are", sbt_maven_hadoop_profiles.keys()) sys.exit(int(os.environ.get("CURRENT_BLOCK", 255))) def build_spark_maven(hadoop_version): # Enable all of the profiles for the build: - build_profiles = get_hadoop_profiles(hadoop_version) + root.build_profile_flags + build_profiles = get_hadoop_profiles(hadoop_version) + modules.root.build_profile_flags mvn_goals = ["clean", "package", "-DskipTests"] profiles_and_goals = build_profiles + mvn_goals - print "[info] Building Spark (w/Hive 0.13.1) using Maven with these arguments:", - print " ".join(profiles_and_goals) + print("[info] Building Spark (w/Hive 0.13.1) using Maven with these arguments: " + " ".join(profiles_and_goals)) exec_maven(profiles_and_goals) def build_spark_sbt(hadoop_version): # Enable all of the profiles for the build: - build_profiles = get_hadoop_profiles(hadoop_version) + root.build_profile_flags + build_profiles = get_hadoop_profiles(hadoop_version) + modules.root.build_profile_flags sbt_goals = ["package", "assembly/assembly", "streaming-kafka-assembly/assembly"] profiles_and_goals = build_profiles + sbt_goals - print "[info] Building Spark (w/Hive 0.13.1) using SBT with these arguments:", - print " ".join(profiles_and_goals) + print("[info] Building Spark (w/Hive 0.13.1) using SBT with these arguments: " + " ".join(profiles_and_goals)) exec_sbt(profiles_and_goals) @@ -648,8 +324,8 @@ def run_scala_tests_maven(test_profiles): mvn_test_goals = ["test", "--fail-at-end"] profiles_and_goals = test_profiles + mvn_test_goals - print "[info] Running Spark tests using Maven with these arguments:", - print " ".join(profiles_and_goals) + print("[info] Running Spark tests using Maven with these arguments: " + " ".join(profiles_and_goals)) exec_maven(profiles_and_goals) @@ -663,8 +339,8 @@ def run_scala_tests_sbt(test_modules, test_profiles): profiles_and_goals = test_profiles + list(sbt_test_goals) - print "[info] Running Spark tests using SBT with these arguments:", - print " ".join(profiles_and_goals) + print("[info] Running Spark tests using SBT with these arguments: " + " ".join(profiles_and_goals)) exec_sbt(profiles_and_goals) @@ -684,10 +360,13 @@ def run_scala_tests(build_tool, hadoop_version, test_modules): run_scala_tests_sbt(test_modules, test_profiles) -def run_python_tests(): +def run_python_tests(test_modules): set_title_and_block("Running PySpark tests", "BLOCK_PYSPARK_UNIT_TESTS") - run_cmd([os.path.join(SPARK_HOME, "python", "run-tests")]) + command = [os.path.join(SPARK_HOME, "python", "run-tests")] + if test_modules != [modules.root]: + command.append("--modules=%s" % ','.join(m.name for m in modules)) + run_cmd(command) def run_sparkr_tests(): @@ -697,14 +376,14 @@ def run_sparkr_tests(): run_cmd([os.path.join(SPARK_HOME, "R", "install-dev.sh")]) run_cmd([os.path.join(SPARK_HOME, "R", "run-tests.sh")]) else: - print "Ignoring SparkR tests as R was not found in PATH" + print("Ignoring SparkR tests as R was not found in PATH") def main(): # Ensure the user home directory (HOME) is valid and is an absolute directory if not USER_HOME or not os.path.isabs(USER_HOME): - print "[error] Cannot determine your home directory as an absolute path;", - print "ensure the $HOME environment variable is set properly." 
+ print("[error] Cannot determine your home directory as an absolute path;" + " ensure the $HOME environment variable is set properly.") sys.exit(1) os.chdir(SPARK_HOME) @@ -718,14 +397,14 @@ def main(): java_exe = determine_java_executable() if not java_exe: - print "[error] Cannot find a version of `java` on the system; please", - print "install one and retry." + print("[error] Cannot find a version of `java` on the system; please" + " install one and retry.") sys.exit(2) java_version = determine_java_version(java_exe) if java_version.minor < 8: - print "[warn] Java 8 tests will not run because JDK version is < 1.8." + print("[warn] Java 8 tests will not run because JDK version is < 1.8.") if os.environ.get("AMPLAB_JENKINS"): # if we're on the Amplab Jenkins build servers setup variables @@ -741,8 +420,8 @@ def main(): hadoop_version = "hadoop2.3" test_env = "local" - print "[info] Using build tool", build_tool, "with Hadoop profile", hadoop_version, - print "under environment", test_env + print("[info] Using build tool", build_tool, "with Hadoop profile", hadoop_version, + "under environment", test_env) changed_modules = None changed_files = None @@ -751,8 +430,9 @@ def main(): changed_files = identify_changed_files_from_git_commits("HEAD", target_branch=target_branch) changed_modules = determine_modules_for_files(changed_files) if not changed_modules: - changed_modules = [root] - print "[info] Found the following changed modules:", ", ".join(x.name for x in changed_modules) + changed_modules = [modules.root] + print("[info] Found the following changed modules:", + ", ".join(x.name for x in changed_modules)) test_modules = determine_modules_to_test(changed_modules) @@ -779,8 +459,9 @@ def main(): # run the test suites run_scala_tests(build_tool, hadoop_version, test_modules) - if any(m.should_run_python_tests for m in test_modules): - run_python_tests() + modules_with_python_tests = [m for m in test_modules if m.python_test_goals] + if modules_with_python_tests: + run_python_tests(modules_with_python_tests) if any(m.should_run_r_tests for m in test_modules): run_sparkr_tests() diff --git a/dev/sparktestsupport/__init__.py b/dev/sparktestsupport/__init__.py new file mode 100644 index 0000000000000..12696d98fb988 --- /dev/null +++ b/dev/sparktestsupport/__init__.py @@ -0,0 +1,21 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import os + +SPARK_HOME = os.path.abspath(os.path.join(os.path.dirname(os.path.realpath(__file__)), "../../")) +USER_HOME = os.environ.get("HOME") diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py new file mode 100644 index 0000000000000..efe3a897e9c10 --- /dev/null +++ b/dev/sparktestsupport/modules.py @@ -0,0 +1,385 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import itertools +import re + +all_modules = [] + + +class Module(object): + """ + A module is the basic abstraction in our test runner script. Each module consists of a set of + source files, a set of test commands, and a set of dependencies on other modules. We use modules + to define a dependency graph that lets determine which tests to run based on which files have + changed. + """ + + def __init__(self, name, dependencies, source_file_regexes, build_profile_flags=(), + sbt_test_goals=(), python_test_goals=(), blacklisted_python_implementations=(), + should_run_r_tests=False): + """ + Define a new module. + + :param name: A short module name, for display in logging and error messages. + :param dependencies: A set of dependencies for this module. This should only include direct + dependencies; transitive dependencies are resolved automatically. + :param source_file_regexes: a set of regexes that match source files belonging to this + module. These regexes are applied by attempting to match at the beginning of the + filename strings. + :param build_profile_flags: A set of profile flags that should be passed to Maven or SBT in + order to build and test this module (e.g. '-PprofileName'). + :param sbt_test_goals: A set of SBT test goals for testing this module. + :param python_test_goals: A set of Python test goals for testing this module. + :param blacklisted_python_implementations: A set of Python implementations that are not + supported by this module's Python components. The values in this set should match + strings returned by Python's `platform.python_implementation()`. + :param should_run_r_tests: If true, changes in this module will trigger all R tests. 
+ """ + self.name = name + self.dependencies = dependencies + self.source_file_prefixes = source_file_regexes + self.sbt_test_goals = sbt_test_goals + self.build_profile_flags = build_profile_flags + self.python_test_goals = python_test_goals + self.blacklisted_python_implementations = blacklisted_python_implementations + self.should_run_r_tests = should_run_r_tests + + self.dependent_modules = set() + for dep in dependencies: + dep.dependent_modules.add(self) + all_modules.append(self) + + def contains_file(self, filename): + return any(re.match(p, filename) for p in self.source_file_prefixes) + + +sql = Module( + name="sql", + dependencies=[], + source_file_regexes=[ + "sql/(?!hive-thriftserver)", + "bin/spark-sql", + ], + build_profile_flags=[ + "-Phive", + ], + sbt_test_goals=[ + "catalyst/test", + "sql/test", + "hive/test", + ] +) + + +hive_thriftserver = Module( + name="hive-thriftserver", + dependencies=[sql], + source_file_regexes=[ + "sql/hive-thriftserver", + "sbin/start-thriftserver.sh", + ], + build_profile_flags=[ + "-Phive-thriftserver", + ], + sbt_test_goals=[ + "hive-thriftserver/test", + ] +) + + +graphx = Module( + name="graphx", + dependencies=[], + source_file_regexes=[ + "graphx/", + ], + sbt_test_goals=[ + "graphx/test" + ] +) + + +streaming = Module( + name="streaming", + dependencies=[], + source_file_regexes=[ + "streaming", + ], + sbt_test_goals=[ + "streaming/test", + ] +) + + +streaming_kinesis_asl = Module( + name="kinesis-asl", + dependencies=[streaming], + source_file_regexes=[ + "extras/kinesis-asl/", + ], + build_profile_flags=[ + "-Pkinesis-asl", + ], + sbt_test_goals=[ + "kinesis-asl/test", + ] +) + + +streaming_zeromq = Module( + name="streaming-zeromq", + dependencies=[streaming], + source_file_regexes=[ + "external/zeromq", + ], + sbt_test_goals=[ + "streaming-zeromq/test", + ] +) + + +streaming_twitter = Module( + name="streaming-twitter", + dependencies=[streaming], + source_file_regexes=[ + "external/twitter", + ], + sbt_test_goals=[ + "streaming-twitter/test", + ] +) + + +streaming_mqtt = Module( + name="streaming-mqtt", + dependencies=[streaming], + source_file_regexes=[ + "external/mqtt", + ], + sbt_test_goals=[ + "streaming-mqtt/test", + ] +) + + +streaming_kafka = Module( + name="streaming-kafka", + dependencies=[streaming], + source_file_regexes=[ + "external/kafka", + "external/kafka-assembly", + ], + sbt_test_goals=[ + "streaming-kafka/test", + ] +) + + +streaming_flume_sink = Module( + name="streaming-flume-sink", + dependencies=[streaming], + source_file_regexes=[ + "external/flume-sink", + ], + sbt_test_goals=[ + "streaming-flume-sink/test", + ] +) + + +streaming_flume = Module( + name="streaming_flume", + dependencies=[streaming], + source_file_regexes=[ + "external/flume", + ], + sbt_test_goals=[ + "streaming-flume/test", + ] +) + + +mllib = Module( + name="mllib", + dependencies=[streaming, sql], + source_file_regexes=[ + "data/mllib/", + "mllib/", + ], + sbt_test_goals=[ + "mllib/test", + ] +) + + +examples = Module( + name="examples", + dependencies=[graphx, mllib, streaming, sql], + source_file_regexes=[ + "examples/", + ], + sbt_test_goals=[ + "examples/test", + ] +) + + +pyspark_core = Module( + name="pyspark-core", + dependencies=[mllib, streaming, streaming_kafka], + source_file_regexes=[ + "python/(?!pyspark/(ml|mllib|sql|streaming))" + ], + python_test_goals=[ + "pyspark.rdd", + "pyspark.context", + "pyspark.conf", + "pyspark.broadcast", + "pyspark.accumulators", + "pyspark.serializers", + "pyspark.profiler", + 
"pyspark.shuffle", + "pyspark.tests", + ] +) + + +pyspark_sql = Module( + name="pyspark-sql", + dependencies=[pyspark_core, sql], + source_file_regexes=[ + "python/pyspark/sql" + ], + python_test_goals=[ + "pyspark.sql.types", + "pyspark.sql.context", + "pyspark.sql.column", + "pyspark.sql.dataframe", + "pyspark.sql.group", + "pyspark.sql.functions", + "pyspark.sql.readwriter", + "pyspark.sql.window", + "pyspark.sql.tests", + ] +) + + +pyspark_streaming = Module( + name="pyspark-streaming", + dependencies=[pyspark_core, streaming, streaming_kafka], + source_file_regexes=[ + "python/pyspark/streaming" + ], + python_test_goals=[ + "pyspark.streaming.util", + "pyspark.streaming.tests", + ] +) + + +pyspark_mllib = Module( + name="pyspark-mllib", + dependencies=[pyspark_core, pyspark_streaming, pyspark_sql, mllib], + source_file_regexes=[ + "python/pyspark/mllib" + ], + python_test_goals=[ + "pyspark.mllib.classification", + "pyspark.mllib.clustering", + "pyspark.mllib.evaluation", + "pyspark.mllib.feature", + "pyspark.mllib.fpm", + "pyspark.mllib.linalg", + "pyspark.mllib.random", + "pyspark.mllib.recommendation", + "pyspark.mllib.regression", + "pyspark.mllib.stat._statistics", + "pyspark.mllib.stat.KernelDensity", + "pyspark.mllib.tree", + "pyspark.mllib.util", + "pyspark.mllib.tests", + ], + blacklisted_python_implementations=[ + "PyPy" # Skip these tests under PyPy since they require numpy and it isn't available there + ] +) + + +pyspark_ml = Module( + name="pyspark-ml", + dependencies=[pyspark_core, pyspark_mllib], + source_file_regexes=[ + "python/pyspark/ml/" + ], + python_test_goals=[ + "pyspark.ml.feature", + "pyspark.ml.classification", + "pyspark.ml.recommendation", + "pyspark.ml.regression", + "pyspark.ml.tuning", + "pyspark.ml.tests", + "pyspark.ml.evaluation", + ], + blacklisted_python_implementations=[ + "PyPy" # Skip these tests under PyPy since they require numpy and it isn't available there + ] +) + +sparkr = Module( + name="sparkr", + dependencies=[sql, mllib], + source_file_regexes=[ + "R/", + ], + should_run_r_tests=True +) + + +docs = Module( + name="docs", + dependencies=[], + source_file_regexes=[ + "docs/", + ] +) + + +ec2 = Module( + name="ec2", + dependencies=[], + source_file_regexes=[ + "ec2/", + ] +) + + +# The root module is a dummy module which is used to run all of the tests. +# No other modules should directly depend on this module. +root = Module( + name="root", + dependencies=[], + source_file_regexes=[], + # In order to run all of the tests, enable every test profile: + build_profile_flags=list(set( + itertools.chain.from_iterable(m.build_profile_flags for m in all_modules))), + sbt_test_goals=[ + "test", + ], + python_test_goals=list(itertools.chain.from_iterable(m.python_test_goals for m in all_modules)), + should_run_r_tests=True +) diff --git a/dev/sparktestsupport/shellutils.py b/dev/sparktestsupport/shellutils.py new file mode 100644 index 0000000000000..ad9b0cc89e4ab --- /dev/null +++ b/dev/sparktestsupport/shellutils.py @@ -0,0 +1,81 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os +import shutil +import subprocess +import sys + + +def exit_from_command_with_retcode(cmd, retcode): + print("[error] running", ' '.join(cmd), "; received return code", retcode) + sys.exit(int(os.environ.get("CURRENT_BLOCK", 255))) + + +def rm_r(path): + """ + Given an arbitrary path, properly remove it with the correct Python construct if it exists. + From: http://stackoverflow.com/a/9559881 + """ + + if os.path.isdir(path): + shutil.rmtree(path) + elif os.path.exists(path): + os.remove(path) + + +def run_cmd(cmd): + """ + Given a command as a list of arguments will attempt to execute the command + and, on failure, print an error message and exit. + """ + + if not isinstance(cmd, list): + cmd = cmd.split() + try: + subprocess.check_call(cmd) + except subprocess.CalledProcessError as e: + exit_from_command_with_retcode(e.cmd, e.returncode) + + +def is_exe(path): + """ + Check if a given path is an executable file. + From: http://stackoverflow.com/a/377028 + """ + + return os.path.isfile(path) and os.access(path, os.X_OK) + + +def which(program): + """ + Find and return the given program by its absolute path or 'None' if the program cannot be found. + From: http://stackoverflow.com/a/377028 + """ + + fpath = os.path.split(program)[0] + + if fpath: + if is_exe(program): + return program + else: + for path in os.environ.get("PATH").split(os.pathsep): + path = path.strip('"') + exe_file = os.path.join(path, program) + if is_exe(exe_file): + return exe_file + return None diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 57049beea4dba..91ce681fbe169 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -15,6 +15,7 @@ # limitations under the License. # +import glob import os import sys from itertools import chain @@ -677,4 +678,19 @@ def test_kafka_rdd_with_leaders(self): self._validateRddResult(sendData, rdd) if __name__ == "__main__": + SPARK_HOME = os.environ["SPARK_HOME"] + kafka_assembly_dir = os.path.join(SPARK_HOME, "external/kafka-assembly") + jars = glob.glob( + os.path.join(kafka_assembly_dir, "target/scala-*/spark-streaming-kafka-assembly-*.jar")) + if not jars: + raise Exception( + ("Failed to find Spark Streaming kafka assembly jar in %s. 
" % kafka_assembly_dir) + + "You need to build Spark with " + "'build/sbt assembly/assembly streaming-kafka-assembly/assembly' or " + "'build/mvn package' before running this test") + elif len(jars) > 1: + raise Exception(("Found multiple Spark Streaming Kafka assembly JARs in %s; please " + "remove all but one") % kafka_assembly_dir) + else: + os.environ["PYSPARK_SUBMIT_ARGS"] = "--jars %s pyspark-shell" % jars[0] unittest.main() diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 78265423682b0..17256dfc95744 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -1421,7 +1421,8 @@ def do_termination_test(self, terminator): # start daemon daemon_path = os.path.join(os.path.dirname(__file__), "daemon.py") - daemon = Popen([sys.executable, daemon_path], stdin=PIPE, stdout=PIPE) + python_exec = sys.executable or os.environ.get("PYSPARK_PYTHON") + daemon = Popen([python_exec, daemon_path], stdin=PIPE, stdout=PIPE) # read the port number port = read_int(daemon.stdout) diff --git a/python/run-tests b/python/run-tests index 4468fdb3f267e..24949657ed7ab 100755 --- a/python/run-tests +++ b/python/run-tests @@ -18,165 +18,7 @@ # -# Figure out where the Spark framework is installed -FWDIR="$(cd "`dirname "$0"`"; cd ../; pwd)" +FWDIR="$(cd "`dirname $0`"/..; pwd)" +cd "$FWDIR" -. "$FWDIR"/bin/load-spark-env.sh - -# CD into the python directory to find things on the right path -cd "$FWDIR/python" - -FAILED=0 -LOG_FILE=unit-tests.log -START=$(date +"%s") - -rm -f $LOG_FILE - -# Remove the metastore and warehouse directory created by the HiveContext tests in Spark SQL -rm -rf metastore warehouse - -function run_test() { - echo -en "Running test: $1 ... " | tee -a $LOG_FILE - start=$(date +"%s") - SPARK_TESTING=1 time "$FWDIR"/bin/pyspark $1 > $LOG_FILE 2>&1 - - FAILED=$((PIPESTATUS[0]||$FAILED)) - - # Fail and exit on the first test failure. - if [[ $FAILED != 0 ]]; then - cat $LOG_FILE | grep -v "^[0-9][0-9]*" # filter all lines starting with a number. - echo -en "\033[31m" # Red - echo "Had test failures; see logs." - echo -en "\033[0m" # No color - exit -1 - else - now=$(date +"%s") - echo "ok ($(($now - $start))s)" - fi -} - -function run_core_tests() { - echo "Run core tests ..." - run_test "pyspark.rdd" - run_test "pyspark.context" - run_test "pyspark.conf" - run_test "pyspark.broadcast" - run_test "pyspark.accumulators" - run_test "pyspark.serializers" - run_test "pyspark.profiler" - run_test "pyspark.shuffle" - run_test "pyspark.tests" -} - -function run_sql_tests() { - echo "Run sql tests ..." - run_test "pyspark.sql.types" - run_test "pyspark.sql.context" - run_test "pyspark.sql.column" - run_test "pyspark.sql.dataframe" - run_test "pyspark.sql.group" - run_test "pyspark.sql.functions" - run_test "pyspark.sql.readwriter" - run_test "pyspark.sql.window" - run_test "pyspark.sql.tests" -} - -function run_mllib_tests() { - echo "Run mllib tests ..." - run_test "pyspark.mllib.classification" - run_test "pyspark.mllib.clustering" - run_test "pyspark.mllib.evaluation" - run_test "pyspark.mllib.feature" - run_test "pyspark.mllib.fpm" - run_test "pyspark.mllib.linalg" - run_test "pyspark.mllib.random" - run_test "pyspark.mllib.recommendation" - run_test "pyspark.mllib.regression" - run_test "pyspark.mllib.stat._statistics" - run_test "pyspark.mllib.stat.KernelDensity" - run_test "pyspark.mllib.tree" - run_test "pyspark.mllib.util" - run_test "pyspark.mllib.tests" -} - -function run_ml_tests() { - echo "Run ml tests ..." 
- run_test "pyspark.ml.feature" - run_test "pyspark.ml.classification" - run_test "pyspark.ml.recommendation" - run_test "pyspark.ml.regression" - run_test "pyspark.ml.tuning" - run_test "pyspark.ml.tests" - run_test "pyspark.ml.evaluation" -} - -function run_streaming_tests() { - echo "Run streaming tests ..." - - KAFKA_ASSEMBLY_DIR="$FWDIR"/external/kafka-assembly - JAR_PATH="${KAFKA_ASSEMBLY_DIR}/target/scala-${SPARK_SCALA_VERSION}" - for f in "${JAR_PATH}"/spark-streaming-kafka-assembly-*.jar; do - if [[ ! -e "$f" ]]; then - echo "Failed to find Spark Streaming Kafka assembly jar in $KAFKA_ASSEMBLY_DIR" 1>&2 - echo "You need to build Spark with " \ - "'build/sbt assembly/assembly streaming-kafka-assembly/assembly' or" \ - "'build/mvn package' before running this program" 1>&2 - exit 1 - fi - KAFKA_ASSEMBLY_JAR="$f" - done - - export PYSPARK_SUBMIT_ARGS="--jars ${KAFKA_ASSEMBLY_JAR} pyspark-shell" - run_test "pyspark.streaming.util" - run_test "pyspark.streaming.tests" -} - -echo "Running PySpark tests. Output is in python/$LOG_FILE." - -export PYSPARK_PYTHON="python" - -# Try to test with Python 2.6, since that's the minimum version that we support: -if [ $(which python2.6) ]; then - export PYSPARK_PYTHON="python2.6" -fi - -echo "Testing with Python version:" -$PYSPARK_PYTHON --version - -run_core_tests -run_sql_tests -run_mllib_tests -run_ml_tests -run_streaming_tests - -# Try to test with Python 3 -if [ $(which python3.4) ]; then - export PYSPARK_PYTHON="python3.4" - echo "Testing with Python3.4 version:" - $PYSPARK_PYTHON --version - - run_core_tests - run_sql_tests - run_mllib_tests - run_ml_tests - run_streaming_tests -fi - -# Try to test with PyPy -if [ $(which pypy) ]; then - export PYSPARK_PYTHON="pypy" - echo "Testing with PyPy version:" - $PYSPARK_PYTHON --version - - run_core_tests - run_sql_tests - run_streaming_tests -fi - -if [[ $FAILED == 0 ]]; then - now=$(date +"%s") - echo -e "\033[32mTests passed \033[0min $(($now - $START)) seconds" -fi - -# TODO: in the long-run, it would be nice to use a test runner like `nose`. -# The doctest fixtures are the current barrier to doing this. +exec python -u ./python/run-tests.py "$@" diff --git a/python/run-tests.py b/python/run-tests.py new file mode 100755 index 0000000000000..7d485b500ee3a --- /dev/null +++ b/python/run-tests.py @@ -0,0 +1,132 @@ +#!/usr/bin/env python + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from __future__ import print_function +from optparse import OptionParser +import os +import re +import subprocess +import sys +import time + + +# Append `SPARK_HOME/dev` to the Python path so that we can import the sparktestsupport module +sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), "../dev/")) + + +from sparktestsupport import SPARK_HOME # noqa (suppress pep8 warnings) +from sparktestsupport.shellutils import which # noqa +from sparktestsupport.modules import all_modules # noqa + + +python_modules = dict((m.name, m) for m in all_modules if m.python_test_goals if m.name != 'root') + + +def print_red(text): + print('\033[31m' + text + '\033[0m') + + +LOG_FILE = os.path.join(SPARK_HOME, "python/unit-tests.log") + + +def run_individual_python_test(test_name, pyspark_python): + env = {'SPARK_TESTING': '1', 'PYSPARK_PYTHON': which(pyspark_python)} + print(" Running test: %s ..." % test_name, end='') + start_time = time.time() + with open(LOG_FILE, 'a') as log_file: + retcode = subprocess.call( + [os.path.join(SPARK_HOME, "bin/pyspark"), test_name], + stderr=log_file, stdout=log_file, env=env) + duration = time.time() - start_time + # Exit on the first failure. + if retcode != 0: + with open(LOG_FILE, 'r') as log_file: + for line in log_file: + if not re.match('[0-9]+', line): + print(line, end='') + print_red("\nHad test failures in %s; see logs." % test_name) + exit(-1) + else: + print("ok (%is)" % duration) + + +def get_default_python_executables(): + python_execs = [x for x in ["python2.6", "python3.4", "pypy"] if which(x)] + if "python2.6" not in python_execs: + print("WARNING: Not testing against `python2.6` because it could not be found; falling" + " back to `python` instead") + python_execs.insert(0, "python") + return python_execs + + +def parse_opts(): + parser = OptionParser( + prog="run-tests" + ) + parser.add_option( + "--python-executables", type="string", default=','.join(get_default_python_executables()), + help="A comma-separated list of Python executables to test against (default: %default)" + ) + parser.add_option( + "--modules", type="string", + default=",".join(sorted(python_modules.keys())), + help="A comma-separated list of Python modules to test (default: %default)" + ) + + (opts, args) = parser.parse_args() + if args: + parser.error("Unsupported arguments: %s" % ' '.join(args)) + return opts + + +def main(): + opts = parse_opts() + print("Running PySpark tests. Output is in python/%s" % LOG_FILE) + if os.path.exists(LOG_FILE): + os.remove(LOG_FILE) + python_execs = opts.python_executables.split(',') + modules_to_test = [] + for module_name in opts.modules.split(','): + if module_name in python_modules: + modules_to_test.append(python_modules[module_name]) + else: + print("Error: unrecognized module %s" % module_name) + sys.exit(-1) + print("Will test against the following Python executables: %s" % python_execs) + print("Will test the following Python modules: %s" % [x.name for x in modules_to_test]) + + start_time = time.time() + for python_exec in python_execs: + python_implementation = subprocess.check_output( + [python_exec, "-c", "import platform; print(platform.python_implementation())"], + universal_newlines=True).strip() + print("Testing with `%s`: " % python_exec, end='') + subprocess.call([python_exec, "--version"]) + + for module in modules_to_test: + if python_implementation not in module.blacklisted_python_implementations: + print("Running %s tests ..." 
% module.name) + for test_goal in module.python_test_goals: + run_individual_python_test(test_goal, python_exec) + total_duration = time.time() - start_time + print("Tests passed in %i seconds" % total_duration) + + +if __name__ == "__main__": + main() From 42db3a1c2fb6db61e01756be7fe88c4110ae638e Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sat, 27 Jun 2015 23:07:20 -0700 Subject: [PATCH 028/122] [HOTFIX] Fix pull request builder bug in #6967 --- dev/run-tests.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dev/run-tests.py b/dev/run-tests.py index c51b0d3010a0f..3533e0c857b9b 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -365,7 +365,7 @@ def run_python_tests(test_modules): command = [os.path.join(SPARK_HOME, "python", "run-tests")] if test_modules != [modules.root]: - command.append("--modules=%s" % ','.join(m.name for m in modules)) + command.append("--modules=%s" % ','.join(m.name for m in test_modules)) run_cmd(command) From f51004519c4c4915711fb9992e3aa4f05fd143ec Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sat, 27 Jun 2015 23:27:52 -0700 Subject: [PATCH 029/122] [SPARK-8683] [BUILD] Depend on mockito-core instead of mockito-all Spark's tests currently depend on `mockito-all`, which bundles Hamcrest and Objenesis classes. Instead, it should depend on `mockito-core`, which declares those libraries as Maven dependencies. This is necessary in order to fix a dependency conflict that leads to a NoSuchMethodError when using certain Hamcrest matchers. See https://github.com/mockito/mockito/wiki/Declaring-mockito-dependency for more details. Author: Josh Rosen Closes #7061 from JoshRosen/mockito-core-instead-of-all and squashes the following commits: 70eccbe [Josh Rosen] Depend on mockito-core instead of mockito-all. --- LICENSE | 2 +- core/pom.xml | 2 +- extras/kinesis-asl/pom.xml | 2 +- launcher/pom.xml | 2 +- mllib/pom.xml | 2 +- network/common/pom.xml | 2 +- network/shuffle/pom.xml | 2 +- pom.xml | 2 +- repl/pom.xml | 2 +- unsafe/pom.xml | 2 +- yarn/pom.xml | 2 +- 11 files changed, 11 insertions(+), 11 deletions(-) diff --git a/LICENSE b/LICENSE index 42010d9f5f0e6..8672be55eca3e 100644 --- a/LICENSE +++ b/LICENSE @@ -948,6 +948,6 @@ The following components are provided under the MIT License. 
See project link fo (MIT License) SLF4J LOG4J-12 Binding (org.slf4j:slf4j-log4j12:1.7.5 - http://www.slf4j.org) (MIT License) pyrolite (org.spark-project:pyrolite:2.0.1 - http://pythonhosted.org/Pyro4/) (MIT License) scopt (com.github.scopt:scopt_2.10:3.2.0 - https://github.com/scopt/scopt) - (The MIT License) Mockito (org.mockito:mockito-all:1.8.5 - http://www.mockito.org) + (The MIT License) Mockito (org.mockito:mockito-core:1.8.5 - http://www.mockito.org) (MIT License) jquery (https://jquery.org/license/) (MIT License) AnchorJS (https://github.com/bryanbraun/anchorjs) diff --git a/core/pom.xml b/core/pom.xml index 40a64beccdc24..565437c4861a4 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -354,7 +354,7 @@ org.mockito - mockito-all + mockito-core test diff --git a/extras/kinesis-asl/pom.xml b/extras/kinesis-asl/pom.xml index c6f60bc907438..c242e7a57b9ab 100644 --- a/extras/kinesis-asl/pom.xml +++ b/extras/kinesis-asl/pom.xml @@ -66,7 +66,7 @@ org.mockito - mockito-all + mockito-core test diff --git a/launcher/pom.xml b/launcher/pom.xml index 48dd0d5f9106b..a853e67f5cf78 100644 --- a/launcher/pom.xml +++ b/launcher/pom.xml @@ -49,7 +49,7 @@ org.mockito - mockito-all + mockito-core test diff --git a/mllib/pom.xml b/mllib/pom.xml index b16058ddc203a..a5db14407b4fc 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -106,7 +106,7 @@ org.mockito - mockito-all + mockito-core test diff --git a/network/common/pom.xml b/network/common/pom.xml index a85e0a66f4a30..7dc3068ab8cb7 100644 --- a/network/common/pom.xml +++ b/network/common/pom.xml @@ -77,7 +77,7 @@ org.mockito - mockito-all + mockito-core test diff --git a/network/shuffle/pom.xml b/network/shuffle/pom.xml index 4b5bfcb6f04bc..532463e96fbb7 100644 --- a/network/shuffle/pom.xml +++ b/network/shuffle/pom.xml @@ -79,7 +79,7 @@ org.mockito - mockito-all + mockito-core test diff --git a/pom.xml b/pom.xml index 80cacb5ace2d4..1aa70240888bc 100644 --- a/pom.xml +++ b/pom.xml @@ -681,7 +681,7 @@ org.mockito - mockito-all + mockito-core 1.9.5 test diff --git a/repl/pom.xml b/repl/pom.xml index 85f7bc8ac1024..370b2bc2fa8ed 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -93,7 +93,7 @@ org.mockito - mockito-all + mockito-core test diff --git a/unsafe/pom.xml b/unsafe/pom.xml index dd2ae6457f0b9..33782c6c66f90 100644 --- a/unsafe/pom.xml +++ b/unsafe/pom.xml @@ -67,7 +67,7 @@ org.mockito - mockito-all + mockito-core test diff --git a/yarn/pom.xml b/yarn/pom.xml index 644def7501dc8..2aeed98285aa8 100644 --- a/yarn/pom.xml +++ b/yarn/pom.xml @@ -107,7 +107,7 @@ org.mockito - mockito-all + mockito-core test From 52d128180166280af443fae84ac61386f3d6c500 Mon Sep 17 00:00:00 2001 From: Thomas Szymanski Date: Sun, 28 Jun 2015 01:06:49 -0700 Subject: [PATCH 030/122] [SPARK-8649] [BUILD] Mapr repository is not defined properly The previous commiter on this part was pwendell The previous url gives 404, the new one seems to be OK. This patch is added under the Apache License 2.0. 
The JIRA link: https://issues.apache.org/jira/browse/SPARK-8649 Author: Thomas Szymanski Closes #7054 from tszym/SPARK-8649 and squashes the following commits: bfda9c4 [Thomas Szymanski] [SPARK-8649] [BUILD] Mapr repository is not defined properly --- pom.xml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pom.xml b/pom.xml index 1aa70240888bc..00f50166b39b6 100644 --- a/pom.xml +++ b/pom.xml @@ -248,7 +248,7 @@ mapr-repo MapR Repository - http://repository.mapr.com/maven + http://repository.mapr.com/maven/ true From 77da5be6f11a7e9cb1d44f7fb97b93481505afe8 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Sun, 28 Jun 2015 08:03:58 -0700 Subject: [PATCH 031/122] [SPARK-8610] [SQL] Separate Row and InternalRow (part 2) Currently, we use GenericRow both for Row and InternalRow, which is confusing because it could contain Scala type also Catalyst types. This PR changes to use GenericInternalRow for InternalRow (contains catalyst types), GenericRow for Row (contains Scala types). Also fixes some incorrect use of InternalRow or Row. Author: Davies Liu Closes #7003 from davies/internalrow and squashes the following commits: d05866c [Davies Liu] fix test: rollback changes for pyspark 72878dd [Davies Liu] Merge branch 'master' of github.com:apache/spark into internalrow efd0b25 [Davies Liu] fix copy of MutableRow 87b13cf [Davies Liu] fix test d2ebd72 [Davies Liu] fix style eb4b473 [Davies Liu] mark expensive API as final bd4e99c [Davies Liu] Merge branch 'master' of github.com:apache/spark into internalrow bdfb78f [Davies Liu] remove BaseMutableRow 6f99a97 [Davies Liu] fix catalyst test defe931 [Davies Liu] remove BaseRow 288b31f [Davies Liu] Merge branch 'master' of github.com:apache/spark into internalrow 9d24350 [Davies Liu] separate Row and InternalRow (part 2) --- .../org/apache/spark/sql/BaseMutableRow.java | 68 ------ .../java/org/apache/spark/sql/BaseRow.java | 197 ------------------ .../sql/catalyst/expressions/UnsafeRow.java | 19 +- .../main/scala/org/apache/spark/sql/Row.scala | 41 ++-- .../sql/catalyst/CatalystTypeConverters.scala | 4 +- .../spark/sql/catalyst/InternalRow.scala | 40 ++-- .../sql/catalyst/expressions/Projection.scala | 50 +---- .../expressions/SpecificMutableRow.scala | 2 +- .../codegen/GenerateMutableProjection.scala | 2 +- .../codegen/GenerateProjection.scala | 16 +- .../sql/catalyst/expressions/generators.scala | 12 +- .../spark/sql/catalyst/expressions/rows.scala | 149 ++++++------- .../expressions/ExpressionEvalHelper.scala | 4 +- .../UnsafeFixedWidthAggregationMapSuite.scala | 6 +- .../org/apache/spark/sql/SQLContext.scala | 24 ++- .../spark/sql/columnar/ColumnType.scala | 70 +++---- .../columnar/InMemoryColumnarTableScan.scala | 3 +- .../sql/execution/SparkSqlSerializer.scala | 21 +- .../sql/execution/SparkSqlSerializer2.scala | 5 +- .../spark/sql/execution/SparkStrategies.scala | 3 +- .../sql/execution/joins/HashOuterJoin.scala | 4 +- .../spark/sql/execution/pythonUdfs.scala | 4 +- .../sql/execution/stat/StatFunctions.scala | 3 +- .../org/apache/spark/sql/jdbc/JDBCRDD.scala | 2 +- .../spark/sql/parquet/ParquetConverter.scala | 8 +- .../apache/spark/sql/sources/commands.scala | 6 +- .../sql/ScalaReflectionRelationSuite.scala | 7 +- .../spark/sql/sources/DDLTestSuite.scala | 2 +- .../spark/sql/sources/TableScanSuite.scala | 4 +- .../spark/sql/hive/HiveInspectors.scala | 5 +- .../apache/spark/sql/hive/TableReader.scala | 3 +- .../hive/execution/CreateTableAsSelect.scala | 14 +- .../execution/DescribeHiveTableCommand.scala | 8 +- 
.../hive/execution/HiveNativeCommand.scala | 8 +- .../sql/hive/execution/HiveTableScan.scala | 2 +- .../hive/execution/ScriptTransformation.scala | 7 +- .../spark/sql/hive/execution/commands.scala | 37 ++-- .../spark/sql/hive/orc/OrcRelation.scala | 10 +- .../spark/sql/hive/HiveInspectorSuite.scala | 4 +- 39 files changed, 299 insertions(+), 575 deletions(-) delete mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/BaseMutableRow.java delete mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/BaseRow.java diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/BaseMutableRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/BaseMutableRow.java deleted file mode 100644 index acec2bf4520f2..0000000000000 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/BaseMutableRow.java +++ /dev/null @@ -1,68 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql; - -import org.apache.spark.sql.catalyst.expressions.MutableRow; - -public abstract class BaseMutableRow extends BaseRow implements MutableRow { - - @Override - public void update(int ordinal, Object value) { - throw new UnsupportedOperationException(); - } - - @Override - public void setInt(int ordinal, int value) { - throw new UnsupportedOperationException(); - } - - @Override - public void setLong(int ordinal, long value) { - throw new UnsupportedOperationException(); - } - - @Override - public void setDouble(int ordinal, double value) { - throw new UnsupportedOperationException(); - } - - @Override - public void setBoolean(int ordinal, boolean value) { - throw new UnsupportedOperationException(); - } - - @Override - public void setShort(int ordinal, short value) { - throw new UnsupportedOperationException(); - } - - @Override - public void setByte(int ordinal, byte value) { - throw new UnsupportedOperationException(); - } - - @Override - public void setFloat(int ordinal, float value) { - throw new UnsupportedOperationException(); - } - - @Override - public void setString(int ordinal, String value) { - throw new UnsupportedOperationException(); - } -} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/BaseRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/BaseRow.java deleted file mode 100644 index 6a2356f1f9c6f..0000000000000 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/BaseRow.java +++ /dev/null @@ -1,197 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql; - -import java.math.BigDecimal; -import java.sql.Date; -import java.sql.Timestamp; -import java.util.List; - -import scala.collection.Seq; -import scala.collection.mutable.ArraySeq; - -import org.apache.spark.sql.catalyst.InternalRow; -import org.apache.spark.sql.catalyst.expressions.GenericRow; -import org.apache.spark.sql.types.StructType; - -public abstract class BaseRow extends InternalRow { - - @Override - final public int length() { - return size(); - } - - @Override - public boolean anyNull() { - final int n = size(); - for (int i=0; i < n; i++) { - if (isNullAt(i)) { - return true; - } - } - return false; - } - - @Override - public StructType schema() { throw new UnsupportedOperationException(); } - - @Override - final public Object apply(int i) { - return get(i); - } - - @Override - public int getInt(int i) { - throw new UnsupportedOperationException(); - } - - @Override - public long getLong(int i) { - throw new UnsupportedOperationException(); - } - - @Override - public float getFloat(int i) { - throw new UnsupportedOperationException(); - } - - @Override - public double getDouble(int i) { - throw new UnsupportedOperationException(); - } - - @Override - public byte getByte(int i) { - throw new UnsupportedOperationException(); - } - - @Override - public short getShort(int i) { - throw new UnsupportedOperationException(); - } - - @Override - public boolean getBoolean(int i) { - throw new UnsupportedOperationException(); - } - - @Override - public String getString(int i) { - throw new UnsupportedOperationException(); - } - - @Override - public BigDecimal getDecimal(int i) { - throw new UnsupportedOperationException(); - } - - @Override - public Date getDate(int i) { - throw new UnsupportedOperationException(); - } - - @Override - public Timestamp getTimestamp(int i) { - throw new UnsupportedOperationException(); - } - - @Override - public Seq getSeq(int i) { - throw new UnsupportedOperationException(); - } - - @Override - public List getList(int i) { - throw new UnsupportedOperationException(); - } - - @Override - public scala.collection.Map getMap(int i) { - throw new UnsupportedOperationException(); - } - - @Override - public scala.collection.immutable.Map getValuesMap(Seq fieldNames) { - throw new UnsupportedOperationException(); - } - - @Override - public java.util.Map getJavaMap(int i) { - throw new UnsupportedOperationException(); - } - - @Override - public Row getStruct(int i) { - throw new UnsupportedOperationException(); - } - - @Override - public T getAs(int i) { - throw new UnsupportedOperationException(); - } - - @Override - public T getAs(String fieldName) { - throw new UnsupportedOperationException(); - } - - @Override - public int fieldIndex(String name) { - throw new UnsupportedOperationException(); - } - - @Override - public InternalRow copy() { - final int n = size(); - Object[] arr = new Object[n]; - for (int i = 0; i < n; i++) { - arr[i] = get(i); - } - return new GenericRow(arr); - } - - @Override - public Seq toSeq() { - final int n = size(); - final ArraySeq values = new ArraySeq(n); - for (int i = 0; i < n; 
i++) { - values.update(i, get(i)); - } - return values; - } - - @Override - public String toString() { - return mkString("[", ",", "]"); - } - - @Override - public String mkString() { - return toSeq().mkString(); - } - - @Override - public String mkString(String sep) { - return toSeq().mkString(sep); - } - - @Override - public String mkString(String start, String sep, String end) { - return toSeq().mkString(start, sep, end); - } -} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index bb2f2079b40f0..11d51d90f1802 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -23,16 +23,12 @@ import java.util.HashSet; import java.util.Set; -import scala.collection.Seq; -import scala.collection.mutable.ArraySeq; - import org.apache.spark.sql.catalyst.InternalRow; -import org.apache.spark.sql.BaseMutableRow; import org.apache.spark.sql.types.DataType; import org.apache.spark.sql.types.StructType; -import org.apache.spark.unsafe.types.UTF8String; import org.apache.spark.unsafe.PlatformDependent; import org.apache.spark.unsafe.bitset.BitSetMethods; +import org.apache.spark.unsafe.types.UTF8String; import static org.apache.spark.sql.types.DataTypes.*; @@ -52,7 +48,7 @@ * * Instances of `UnsafeRow` act as pointers to row data stored in this format. */ -public final class UnsafeRow extends BaseMutableRow { +public final class UnsafeRow extends MutableRow { private Object baseObject; private long baseOffset; @@ -63,6 +59,8 @@ public final class UnsafeRow extends BaseMutableRow { /** The number of fields in this row, used for calculating the bitset width (and in assertions) */ private int numFields; + public int length() { return numFields; } + /** The width of the null tracking bit set, in bytes */ private int bitSetWidthInBytes; /** @@ -344,13 +342,4 @@ public InternalRow copy() { public boolean anyNull() { return BitSetMethods.anySet(baseObject, baseOffset, bitSetWidthInBytes); } - - @Override - public Seq toSeq() { - final ArraySeq values = new ArraySeq(numFields); - for (int fieldNumber = 0; fieldNumber < numFields; fieldNumber++) { - values.update(fieldNumber, get(fieldNumber)); - } - return values; - } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala index e99d5c87a44fe..0f2fd6a86d177 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala @@ -179,7 +179,7 @@ trait Row extends Serializable { def get(i: Int): Any = apply(i) /** Checks whether the value at position i is null. */ - def isNullAt(i: Int): Boolean + def isNullAt(i: Int): Boolean = apply(i) == null /** * Returns the value at position i as a primitive boolean. @@ -187,7 +187,7 @@ trait Row extends Serializable { * @throws ClassCastException when data type does not match. * @throws NullPointerException when value is null. */ - def getBoolean(i: Int): Boolean + def getBoolean(i: Int): Boolean = getAs[Boolean](i) /** * Returns the value at position i as a primitive byte. @@ -195,7 +195,7 @@ trait Row extends Serializable { * @throws ClassCastException when data type does not match. * @throws NullPointerException when value is null. 
*/ - def getByte(i: Int): Byte + def getByte(i: Int): Byte = getAs[Byte](i) /** * Returns the value at position i as a primitive short. @@ -203,7 +203,7 @@ trait Row extends Serializable { * @throws ClassCastException when data type does not match. * @throws NullPointerException when value is null. */ - def getShort(i: Int): Short + def getShort(i: Int): Short = getAs[Short](i) /** * Returns the value at position i as a primitive int. @@ -211,7 +211,7 @@ trait Row extends Serializable { * @throws ClassCastException when data type does not match. * @throws NullPointerException when value is null. */ - def getInt(i: Int): Int + def getInt(i: Int): Int = getAs[Int](i) /** * Returns the value at position i as a primitive long. @@ -219,7 +219,7 @@ trait Row extends Serializable { * @throws ClassCastException when data type does not match. * @throws NullPointerException when value is null. */ - def getLong(i: Int): Long + def getLong(i: Int): Long = getAs[Long](i) /** * Returns the value at position i as a primitive float. @@ -228,7 +228,7 @@ trait Row extends Serializable { * @throws ClassCastException when data type does not match. * @throws NullPointerException when value is null. */ - def getFloat(i: Int): Float + def getFloat(i: Int): Float = getAs[Float](i) /** * Returns the value at position i as a primitive double. @@ -236,7 +236,7 @@ trait Row extends Serializable { * @throws ClassCastException when data type does not match. * @throws NullPointerException when value is null. */ - def getDouble(i: Int): Double + def getDouble(i: Int): Double = getAs[Double](i) /** * Returns the value at position i as a String object. @@ -244,35 +244,35 @@ trait Row extends Serializable { * @throws ClassCastException when data type does not match. * @throws NullPointerException when value is null. */ - def getString(i: Int): String + def getString(i: Int): String = getAs[String](i) /** * Returns the value at position i of decimal type as java.math.BigDecimal. * * @throws ClassCastException when data type does not match. */ - def getDecimal(i: Int): java.math.BigDecimal = apply(i).asInstanceOf[java.math.BigDecimal] + def getDecimal(i: Int): java.math.BigDecimal = getAs[java.math.BigDecimal](i) /** * Returns the value at position i of date type as java.sql.Date. * * @throws ClassCastException when data type does not match. */ - def getDate(i: Int): java.sql.Date = apply(i).asInstanceOf[java.sql.Date] + def getDate(i: Int): java.sql.Date = getAs[java.sql.Date](i) /** * Returns the value at position i of date type as java.sql.Timestamp. * * @throws ClassCastException when data type does not match. */ - def getTimestamp(i: Int): java.sql.Timestamp = apply(i).asInstanceOf[java.sql.Timestamp] + def getTimestamp(i: Int): java.sql.Timestamp = getAs[java.sql.Timestamp](i) /** * Returns the value at position i of array type as a Scala Seq. * * @throws ClassCastException when data type does not match. */ - def getSeq[T](i: Int): Seq[T] = apply(i).asInstanceOf[Seq[T]] + def getSeq[T](i: Int): Seq[T] = getAs[Seq[T]](i) /** * Returns the value at position i of array type as [[java.util.List]]. @@ -288,7 +288,7 @@ trait Row extends Serializable { * * @throws ClassCastException when data type does not match. */ - def getMap[K, V](i: Int): scala.collection.Map[K, V] = apply(i).asInstanceOf[Map[K, V]] + def getMap[K, V](i: Int): scala.collection.Map[K, V] = getAs[Map[K, V]](i) /** * Returns the value at position i of array type as a [[java.util.Map]]. 
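The hunks above change Row's typed accessors (getBoolean, getInt, getString, getDecimal, getSeq, ...) from abstract methods into default implementations that all delegate to the generic getAs[T], so a concrete row only has to provide element access and a length. A minimal, self-contained sketch of that pattern follows; MiniRow, ArrayRow and MiniRowDemo are illustrative names for this sketch only, not Spark's actual Row hierarchy.

// Sketch of "typed getters delegate to one generic accessor"; assumes nothing
// beyond the Scala standard library.
trait MiniRow {
  def length: Int
  def apply(i: Int): Any

  // Default implementations, mirroring the pattern in the patch above.
  def isNullAt(i: Int): Boolean = apply(i) == null
  def getAs[T](i: Int): T = apply(i).asInstanceOf[T]
  def getInt(i: Int): Int = getAs[Int](i)
  def getString(i: Int): String = getAs[String](i)
  def toSeq: Seq[Any] = (0 until length).map(apply)
}

// A concrete row backed by an array gets every accessor for free.
class ArrayRow(values: Array[Any]) extends MiniRow {
  override def length: Int = values.length
  override def apply(i: Int): Any = values(i)
}

object MiniRowDemo {
  def main(args: Array[String]): Unit = {
    val row = new ArrayRow(Array[Any](42, "spark", null))
    println(row.getInt(0))      // 42
    println(row.getString(1))   // spark
    println(row.isNullAt(2))    // true
  }
}

Funneling everything through a single accessor keeps the convenience methods cheap to add and lets implementations override only the accessors they need to make fast.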
@@ -366,9 +366,18 @@ trait Row extends Serializable { /* ---------------------- utility methods for Scala ---------------------- */ /** - * Return a Scala Seq representing the row. ELements are placed in the same order in the Seq. + * Return a Scala Seq representing the row. Elements are placed in the same order in the Seq. */ - def toSeq: Seq[Any] + def toSeq: Seq[Any] = { + val n = length + val values = new Array[Any](n) + var i = 0 + while (i < n) { + values.update(i, get(i)) + i += 1 + } + values.toSeq + } /** Displays all elements of this sequence in a string (without a separator). */ def mkString: String = toSeq.mkString diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala index 012f8bbecb4d3..8f63d2120ad0e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala @@ -242,7 +242,7 @@ object CatalystTypeConverters { ar(idx) = converters(idx).toCatalyst(row(idx)) idx += 1 } - new GenericRowWithSchema(ar, structType) + new GenericInternalRow(ar) case p: Product => val ar = new Array[Any](structType.size) @@ -252,7 +252,7 @@ object CatalystTypeConverters { ar(idx) = converters(idx).toCatalyst(iter.next()) idx += 1 } - new GenericRowWithSchema(ar, structType) + new GenericInternalRow(ar) } override def toScala(row: InternalRow): Row = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala index d7b537a9fe3bc..61a29c89d8df3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala @@ -19,14 +19,38 @@ package org.apache.spark.sql.catalyst import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.unsafe.types.UTF8String /** * An abstract class for row used internal in Spark SQL, which only contain the columns as * internal types. */ abstract class InternalRow extends Row { + + // This is only use for test + override def getString(i: Int): String = getAs[UTF8String](i).toString + + // These expensive API should not be used internally. 
+ final override def getDecimal(i: Int): java.math.BigDecimal = + throw new UnsupportedOperationException + final override def getDate(i: Int): java.sql.Date = + throw new UnsupportedOperationException + final override def getTimestamp(i: Int): java.sql.Timestamp = + throw new UnsupportedOperationException + final override def getSeq[T](i: Int): Seq[T] = throw new UnsupportedOperationException + final override def getList[T](i: Int): java.util.List[T] = throw new UnsupportedOperationException + final override def getMap[K, V](i: Int): scala.collection.Map[K, V] = + throw new UnsupportedOperationException + final override def getJavaMap[K, V](i: Int): java.util.Map[K, V] = + throw new UnsupportedOperationException + final override def getStruct(i: Int): Row = throw new UnsupportedOperationException + final override def getAs[T](fieldName: String): T = throw new UnsupportedOperationException + final override def getValuesMap[T](fieldNames: Seq[String]): Map[String, T] = + throw new UnsupportedOperationException + // A default implementation to change the return type override def copy(): InternalRow = this + override def apply(i: Int): Any = get(i) override def equals(o: Any): Boolean = { if (!o.isInstanceOf[Row]) { @@ -93,27 +117,15 @@ abstract class InternalRow extends Row { } object InternalRow { - def unapplySeq(row: InternalRow): Some[Seq[Any]] = Some(row.toSeq) - /** * This method can be used to construct a [[Row]] with the given values. */ - def apply(values: Any*): InternalRow = new GenericRow(values.toArray) + def apply(values: Any*): InternalRow = new GenericInternalRow(values.toArray) /** * This method can be used to construct a [[Row]] from a [[Seq]] of values. */ - def fromSeq(values: Seq[Any]): InternalRow = new GenericRow(values.toArray) - - def fromTuple(tuple: Product): InternalRow = fromSeq(tuple.productIterator.toSeq) - - /** - * Merge multiple rows into a single row, one after another. - */ - def merge(rows: InternalRow*): InternalRow = { - // TODO: Improve the performance of this if used in performance critical part. - new GenericRow(rows.flatMap(_.toSeq).toArray) - } + def fromSeq(values: Seq[Any]): InternalRow = new GenericInternalRow(values.toArray) /** Returns an empty row. 
*/ val empty = apply() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index d5967438ccb5a..fcfe83ceb863a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -36,7 +36,7 @@ class InterpretedProjection(expressions: Seq[Expression]) extends Projection { outputArray(i) = exprArray(i).eval(input) i += 1 } - new GenericRow(outputArray) + new GenericInternalRow(outputArray) } override def toString: String = s"Row => [${exprArray.mkString(",")}]" @@ -135,12 +135,6 @@ class JoinedRow extends InternalRow { override def getFloat(i: Int): Float = if (i < row1.length) row1.getFloat(i) else row2.getFloat(i - row1.length) - override def getString(i: Int): String = - if (i < row1.length) row1.getString(i) else row2.getString(i - row1.length) - - override def getAs[T](i: Int): T = - if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length) - override def copy(): InternalRow = { val totalSize = row1.length + row2.length val copiedValues = new Array[Any](totalSize) @@ -149,7 +143,7 @@ class JoinedRow extends InternalRow { copiedValues(i) = apply(i) i += 1 } - new GenericRow(copiedValues) + new GenericInternalRow(copiedValues) } override def toString: String = { @@ -235,12 +229,6 @@ class JoinedRow2 extends InternalRow { override def getFloat(i: Int): Float = if (i < row1.length) row1.getFloat(i) else row2.getFloat(i - row1.length) - override def getString(i: Int): String = - if (i < row1.length) row1.getString(i) else row2.getString(i - row1.length) - - override def getAs[T](i: Int): T = - if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length) - override def copy(): InternalRow = { val totalSize = row1.length + row2.length val copiedValues = new Array[Any](totalSize) @@ -249,7 +237,7 @@ class JoinedRow2 extends InternalRow { copiedValues(i) = apply(i) i += 1 } - new GenericRow(copiedValues) + new GenericInternalRow(copiedValues) } override def toString: String = { @@ -329,12 +317,6 @@ class JoinedRow3 extends InternalRow { override def getFloat(i: Int): Float = if (i < row1.length) row1.getFloat(i) else row2.getFloat(i - row1.length) - override def getString(i: Int): String = - if (i < row1.length) row1.getString(i) else row2.getString(i - row1.length) - - override def getAs[T](i: Int): T = - if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length) - override def copy(): InternalRow = { val totalSize = row1.length + row2.length val copiedValues = new Array[Any](totalSize) @@ -343,7 +325,7 @@ class JoinedRow3 extends InternalRow { copiedValues(i) = apply(i) i += 1 } - new GenericRow(copiedValues) + new GenericInternalRow(copiedValues) } override def toString: String = { @@ -423,12 +405,6 @@ class JoinedRow4 extends InternalRow { override def getFloat(i: Int): Float = if (i < row1.length) row1.getFloat(i) else row2.getFloat(i - row1.length) - override def getString(i: Int): String = - if (i < row1.length) row1.getString(i) else row2.getString(i - row1.length) - - override def getAs[T](i: Int): T = - if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length) - override def copy(): InternalRow = { val totalSize = row1.length + row2.length val copiedValues = new Array[Any](totalSize) @@ -437,7 +413,7 @@ class JoinedRow4 extends InternalRow { copiedValues(i) = 
apply(i) i += 1 } - new GenericRow(copiedValues) + new GenericInternalRow(copiedValues) } override def toString: String = { @@ -517,12 +493,6 @@ class JoinedRow5 extends InternalRow { override def getFloat(i: Int): Float = if (i < row1.length) row1.getFloat(i) else row2.getFloat(i - row1.length) - override def getString(i: Int): String = - if (i < row1.length) row1.getString(i) else row2.getString(i - row1.length) - - override def getAs[T](i: Int): T = - if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length) - override def copy(): InternalRow = { val totalSize = row1.length + row2.length val copiedValues = new Array[Any](totalSize) @@ -531,7 +501,7 @@ class JoinedRow5 extends InternalRow { copiedValues(i) = apply(i) i += 1 } - new GenericRow(copiedValues) + new GenericInternalRow(copiedValues) } override def toString: String = { @@ -611,12 +581,6 @@ class JoinedRow6 extends InternalRow { override def getFloat(i: Int): Float = if (i < row1.length) row1.getFloat(i) else row2.getFloat(i - row1.length) - override def getString(i: Int): String = - if (i < row1.length) row1.getString(i) else row2.getString(i - row1.length) - - override def getAs[T](i: Int): T = - if (i < row1.length) row1.getAs[T](i) else row2.getAs[T](i - row1.length) - override def copy(): InternalRow = { val totalSize = row1.length + row2.length val copiedValues = new Array[Any](totalSize) @@ -625,7 +589,7 @@ class JoinedRow6 extends InternalRow { copiedValues(i) = apply(i) i += 1 } - new GenericRow(copiedValues) + new GenericInternalRow(copiedValues) } override def toString: String = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala index 05aab34559985..53fedb531cfb2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala @@ -230,7 +230,7 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR i += 1 } - new GenericRow(newValues) + new GenericInternalRow(newValues) } override def update(ordinal: Int, value: Any) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index e75e82d380541..64ef357a4f954 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen import org.apache.spark.sql.catalyst.expressions._ // MutableProjection is not accessible in Java -abstract class BaseMutableProjection extends MutableProjection {} +abstract class BaseMutableProjection extends MutableProjection /** * Generates byte code that produces a [[MutableRow]] object that can update itself based on a new diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala index 624e1cf4e201a..39d32b78cc14a 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.expressions.codegen -import org.apache.spark.sql.BaseMutableRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ @@ -149,6 +148,10 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { """ }.mkString("\n") + val copyColumns = expressions.zipWithIndex.map { case (e, i) => + s"""arr[$i] = c$i;""" + }.mkString("\n ") + val code = s""" public SpecificProjection generate($exprType[] expr) { return new SpecificProjection(expr); @@ -167,7 +170,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { } } - final class SpecificRow extends ${typeOf[BaseMutableRow]} { + final class SpecificRow extends ${typeOf[MutableRow]} { $columns @@ -175,7 +178,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { $initColumns } - public int size() { return ${expressions.length};} + public int length() { return ${expressions.length};} protected boolean[] nullBits = new boolean[${expressions.length}]; public void setNullAt(int i) { nullBits[i] = true; } public boolean isNullAt(int i) { return nullBits[i]; } @@ -216,6 +219,13 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { } return super.equals(other); } + + @Override + public InternalRow copy() { + Object[] arr = new Object[${expressions.length}]; + ${copyColumns} + return new ${typeOf[GenericInternalRow]}(arr); + } } """ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index 356560e54cae3..7a42a1d310581 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import scala.collection.Map +import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.{CatalystTypeConverters, trees} import org.apache.spark.sql.types._ @@ -68,19 +69,19 @@ abstract class Generator extends Expression { */ case class UserDefinedGenerator( elementTypes: Seq[(DataType, Boolean)], - function: InternalRow => TraversableOnce[InternalRow], + function: Row => TraversableOnce[InternalRow], children: Seq[Expression]) extends Generator { @transient private[this] var inputRow: InterpretedProjection = _ - @transient private[this] var convertToScala: (InternalRow) => InternalRow = _ + @transient private[this] var convertToScala: (InternalRow) => Row = _ private def initializeConverters(): Unit = { inputRow = new InterpretedProjection(children) convertToScala = { val inputSchema = StructType(children.map(e => StructField(e.simpleString, e.dataType, true))) CatalystTypeConverters.createToScalaConverter(inputSchema) - }.asInstanceOf[(InternalRow => InternalRow)] + }.asInstanceOf[InternalRow => Row] } override def eval(input: InternalRow): TraversableOnce[InternalRow] = { @@ -118,10 +119,11 @@ case class Explode(child: Expression) child.dataType match { case ArrayType(_, _) => val inputArray = child.eval(input).asInstanceOf[Seq[Any]] - if (inputArray == null) Nil else 
inputArray.map(v => new GenericRow(Array(v))) + if (inputArray == null) Nil else inputArray.map(v => InternalRow(v)) case MapType(_, _, _) => val inputMap = child.eval(input).asInstanceOf[Map[Any, Any]] - if (inputMap == null) Nil else inputMap.map { case (k, v) => new GenericRow(Array(k, v)) } + if (inputMap == null) Nil + else inputMap.map { case (k, v) => InternalRow(k, v) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index 0d4c9ace5e124..dd5f2ed2d382e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.expressions +import org.apache.spark.sql.Row import org.apache.spark.sql.types.{DataType, StructType, AtomicType} import org.apache.spark.unsafe.types.UTF8String @@ -24,19 +25,32 @@ import org.apache.spark.unsafe.types.UTF8String * An extended interface to [[InternalRow]] that allows the values for each column to be updated. * Setting a value through a primitive function implicitly marks that column as not null. */ -trait MutableRow extends InternalRow { +abstract class MutableRow extends InternalRow { def setNullAt(i: Int): Unit - def update(ordinal: Int, value: Any) + def update(i: Int, value: Any) + + // default implementation (slow) + def setInt(i: Int, value: Int): Unit = { update(i, value) } + def setLong(i: Int, value: Long): Unit = { update(i, value) } + def setDouble(i: Int, value: Double): Unit = { update(i, value) } + def setBoolean(i: Int, value: Boolean): Unit = { update(i, value) } + def setShort(i: Int, value: Short): Unit = { update(i, value) } + def setByte(i: Int, value: Byte): Unit = { update(i, value) } + def setFloat(i: Int, value: Float): Unit = { update(i, value) } + def setString(i: Int, value: String): Unit = { + update(i, UTF8String.fromString(value)) + } - def setInt(ordinal: Int, value: Int) - def setLong(ordinal: Int, value: Long) - def setDouble(ordinal: Int, value: Double) - def setBoolean(ordinal: Int, value: Boolean) - def setShort(ordinal: Int, value: Short) - def setByte(ordinal: Int, value: Byte) - def setFloat(ordinal: Int, value: Float) - def setString(ordinal: Int, value: String) + override def copy(): InternalRow = { + val arr = new Array[Any](length) + var i = 0 + while (i < length) { + arr(i) = get(i) + i += 1 + } + new GenericInternalRow(arr) + } } /** @@ -60,68 +74,57 @@ object EmptyRow extends InternalRow { } /** - * A row implementation that uses an array of objects as the underlying storage. Note that, while - * the array is not copied, and thus could technically be mutated after creation, this is not - * allowed. + * A row implementation that uses an array of objects as the underlying storage. */ -class GenericRow(protected[sql] val values: Array[Any]) extends InternalRow { - /** No-arg constructor for serialization. 
*/ - protected def this() = this(null) +trait ArrayBackedRow { + self: Row => - def this(size: Int) = this(new Array[Any](size)) + protected val values: Array[Any] override def toSeq: Seq[Any] = values.toSeq - override def length: Int = values.length + def length: Int = values.length override def apply(i: Int): Any = values(i) - override def isNullAt(i: Int): Boolean = values(i) == null - - override def getInt(i: Int): Int = { - if (values(i) == null) sys.error("Failed to check null bit for primitive int value.") - values(i).asInstanceOf[Int] - } - - override def getLong(i: Int): Long = { - if (values(i) == null) sys.error("Failed to check null bit for primitive long value.") - values(i).asInstanceOf[Long] - } - - override def getDouble(i: Int): Double = { - if (values(i) == null) sys.error("Failed to check null bit for primitive double value.") - values(i).asInstanceOf[Double] - } - - override def getFloat(i: Int): Float = { - if (values(i) == null) sys.error("Failed to check null bit for primitive float value.") - values(i).asInstanceOf[Float] - } + def setNullAt(i: Int): Unit = { values(i) = null} - override def getBoolean(i: Int): Boolean = { - if (values(i) == null) sys.error("Failed to check null bit for primitive boolean value.") - values(i).asInstanceOf[Boolean] - } + def update(i: Int, value: Any): Unit = { values(i) = value } +} - override def getShort(i: Int): Short = { - if (values(i) == null) sys.error("Failed to check null bit for primitive short value.") - values(i).asInstanceOf[Short] - } +/** + * A row implementation that uses an array of objects as the underlying storage. Note that, while + * the array is not copied, and thus could technically be mutated after creation, this is not + * allowed. + */ +class GenericRow(protected[sql] val values: Array[Any]) extends Row with ArrayBackedRow { + /** No-arg constructor for serialization. */ + protected def this() = this(null) - override def getByte(i: Int): Byte = { - if (values(i) == null) sys.error("Failed to check null bit for primitive byte value.") - values(i).asInstanceOf[Byte] - } + def this(size: Int) = this(new Array[Any](size)) - override def getString(i: Int): String = { - values(i) match { - case null => null - case s: String => s - case utf8: UTF8String => utf8.toString - } + // This is used by test or outside + override def equals(o: Any): Boolean = o match { + case other: Row if other.length == length => + var i = 0 + while (i < length) { + if (isNullAt(i) != other.isNullAt(i)) { + return false + } + val equal = (apply(i), other.apply(i)) match { + case (a: Array[Byte], b: Array[Byte]) => java.util.Arrays.equals(a, b) + case (a, b) => a == b + } + if (!equal) { + return false + } + i += 1 + } + true + case _ => false } - override def copy(): InternalRow = this + override def copy(): Row = this } class GenericRowWithSchema(values: Array[Any], override val schema: StructType) @@ -133,32 +136,30 @@ class GenericRowWithSchema(values: Array[Any], override val schema: StructType) override def fieldIndex(name: String): Int = schema.fieldIndex(name) } -class GenericMutableRow(v: Array[Any]) extends GenericRow(v) with MutableRow { +/** + * A internal row implementation that uses an array of objects as the underlying storage. + * Note that, while the array is not copied, and thus could technically be mutated after creation, + * this is not allowed. + */ +class GenericInternalRow(protected[sql] val values: Array[Any]) + extends InternalRow with ArrayBackedRow { /** No-arg constructor for serialization. 
*/ protected def this() = this(null) def this(size: Int) = this(new Array[Any](size)) - override def setBoolean(ordinal: Int, value: Boolean): Unit = { values(ordinal) = value } - override def setByte(ordinal: Int, value: Byte): Unit = { values(ordinal) = value } - override def setDouble(ordinal: Int, value: Double): Unit = { values(ordinal) = value } - override def setFloat(ordinal: Int, value: Float): Unit = { values(ordinal) = value } - override def setInt(ordinal: Int, value: Int): Unit = { values(ordinal) = value } - override def setLong(ordinal: Int, value: Long): Unit = { values(ordinal) = value } - override def setString(ordinal: Int, value: String): Unit = { - values(ordinal) = UTF8String.fromString(value) - } - - override def setNullAt(i: Int): Unit = { values(i) = null } + override def copy(): InternalRow = this +} - override def setShort(ordinal: Int, value: Short): Unit = { values(ordinal) = value } +class GenericMutableRow(val values: Array[Any]) extends MutableRow with ArrayBackedRow { + /** No-arg constructor for serialization. */ + protected def this() = this(null) - override def update(ordinal: Int, value: Any): Unit = { values(ordinal) = value } + def this(size: Int) = this(new Array[Any](size)) - override def copy(): InternalRow = new GenericRow(values.clone()) + override def copy(): InternalRow = new GenericInternalRow(values.clone()) } - class RowOrdering(ordering: Seq[SortOrder]) extends Ordering[InternalRow] { def this(ordering: Seq[SortOrder], inputSchema: Seq[Attribute]) = this(ordering.map(BindReferences.bindReference(_, inputSchema))) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index 158f54af13802..7d95ef7f710af 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -33,7 +33,7 @@ trait ExpressionEvalHelper { self: SparkFunSuite => protected def create_row(values: Any*): InternalRow = { - new GenericRow(values.map(CatalystTypeConverters.convertToCatalyst).toArray) + InternalRow.fromSeq(values.map(CatalystTypeConverters.convertToCatalyst)) } protected def checkEvaluation( @@ -122,7 +122,7 @@ trait ExpressionEvalHelper { } val actual = plan(inputRow) - val expectedRow = new GenericRow(Array[Any](expected)) + val expectedRow = InternalRow(expected) if (actual.hashCode() != expectedRow.hashCode()) { fail( s""" diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala index 7aae2bbd8a0b8..3095ccb77761b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala @@ -37,7 +37,7 @@ class UnsafeFixedWidthAggregationMapSuite private val groupKeySchema = StructType(StructField("product", StringType) :: Nil) private val aggBufferSchema = StructType(StructField("salePrice", IntegerType) :: Nil) - private def emptyAggregationBuffer: InternalRow = new GenericRow(Array[Any](0)) + private def emptyAggregationBuffer: InternalRow = InternalRow(0) private var memoryManager: TaskMemoryManager = null @@ 
-84,7 +84,7 @@ class UnsafeFixedWidthAggregationMapSuite 1024, // initial capacity false // disable perf metrics ) - val groupKey = new GenericRow(Array[Any](UTF8String.fromString("cats"))) + val groupKey = InternalRow(UTF8String.fromString("cats")) // Looking up a key stores a zero-entry in the map (like Python Counters or DefaultDicts) map.getAggregationBuffer(groupKey) @@ -113,7 +113,7 @@ class UnsafeFixedWidthAggregationMapSuite val rand = new Random(42) val groupKeys: Set[String] = Seq.fill(512)(rand.nextString(1024)).toSet groupKeys.foreach { keyString => - map.getAggregationBuffer(new GenericRow(Array[Any](UTF8String.fromString(keyString)))) + map.getAggregationBuffer(InternalRow(UTF8String.fromString(keyString))) } val seenKeys: Set[String] = map.iterator().asScala.map { entry => entry.key.getString(0) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 5708df82de12f..8ed44ee141be5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -42,6 +42,7 @@ import org.apache.spark.sql.catalyst.{InternalRow, ParserDialect, _} import org.apache.spark.sql.execution.{Filter, _} import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils /** @@ -377,10 +378,11 @@ class SQLContext(@transient val sparkContext: SparkContext) val row = new SpecificMutableRow(dataType :: Nil) iter.map { v => row.setInt(0, v) - row: Row + row: InternalRow } } - DataFrameHolder(self.createDataFrame(rows, StructType(StructField("_1", dataType) :: Nil))) + DataFrameHolder( + self.internalCreateDataFrame(rows, StructType(StructField("_1", dataType) :: Nil))) } /** @@ -393,10 +395,11 @@ class SQLContext(@transient val sparkContext: SparkContext) val row = new SpecificMutableRow(dataType :: Nil) iter.map { v => row.setLong(0, v) - row: Row + row: InternalRow } } - DataFrameHolder(self.createDataFrame(rows, StructType(StructField("_1", dataType) :: Nil))) + DataFrameHolder( + self.internalCreateDataFrame(rows, StructType(StructField("_1", dataType) :: Nil))) } /** @@ -408,11 +411,12 @@ class SQLContext(@transient val sparkContext: SparkContext) val rows = data.mapPartitions { iter => val row = new SpecificMutableRow(dataType :: Nil) iter.map { v => - row.setString(0, v) - row: Row + row.update(0, UTF8String.fromString(v)) + row: InternalRow } } - DataFrameHolder(self.createDataFrame(rows, StructType(StructField("_1", dataType) :: Nil))) + DataFrameHolder( + self.internalCreateDataFrame(rows, StructType(StructField("_1", dataType) :: Nil))) } } @@ -559,9 +563,9 @@ class SQLContext(@transient val sparkContext: SparkContext) (e, CatalystTypeConverters.createToCatalystConverter(attr.dataType)) } iter.map { row => - new GenericRow( + new GenericInternalRow( methodsToConverts.map { case (e, convert) => convert(e.invoke(row)) }.toArray[Any] - ) : InternalRow + ): InternalRow } } DataFrame(this, LogicalRDD(attributeSeq, rowRdd)(this)) @@ -1065,7 +1069,7 @@ class SQLContext(@transient val sparkContext: SparkContext) } val rowRdd = convertedRdd.mapPartitions { iter => - iter.map { m => new GenericRow(m): InternalRow} + iter.map { m => new GenericInternalRow(m): InternalRow} } DataFrame(this, LogicalRDD(schema.toAttributes, rowRdd)(self)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala index 8e21020917768..8bf2151e4de68 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala @@ -21,7 +21,7 @@ import java.nio.ByteBuffer import scala.reflect.runtime.universe.TypeTag -import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.MutableRow import org.apache.spark.sql.execution.SparkSqlSerializer import org.apache.spark.sql.types._ @@ -63,7 +63,7 @@ private[sql] sealed abstract class ColumnType[T <: DataType, JvmType]( * Appends `row(ordinal)` of type T into the given ByteBuffer. Subclasses should override this * method to avoid boxing/unboxing costs whenever possible. */ - def append(row: Row, ordinal: Int, buffer: ByteBuffer): Unit = { + def append(row: InternalRow, ordinal: Int, buffer: ByteBuffer): Unit = { append(getField(row, ordinal), buffer) } @@ -71,13 +71,13 @@ private[sql] sealed abstract class ColumnType[T <: DataType, JvmType]( * Returns the size of the value `row(ordinal)`. This is used to calculate the size of variable * length types such as byte arrays and strings. */ - def actualSize(row: Row, ordinal: Int): Int = defaultSize + def actualSize(row: InternalRow, ordinal: Int): Int = defaultSize /** * Returns `row(ordinal)`. Subclasses should override this method to avoid boxing/unboxing costs * whenever possible. */ - def getField(row: Row, ordinal: Int): JvmType + def getField(row: InternalRow, ordinal: Int): JvmType /** * Sets `row(ordinal)` to `field`. Subclasses should override this method to avoid boxing/unboxing @@ -89,7 +89,7 @@ private[sql] sealed abstract class ColumnType[T <: DataType, JvmType]( * Copies `from(fromOrdinal)` to `to(toOrdinal)`. Subclasses should override this method to avoid * boxing/unboxing costs whenever possible. 
*/ - def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { + def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { to(toOrdinal) = from(fromOrdinal) } @@ -118,7 +118,7 @@ private[sql] object INT extends NativeColumnType(IntegerType, 0, 4) { buffer.putInt(v) } - override def append(row: Row, ordinal: Int, buffer: ByteBuffer): Unit = { + override def append(row: InternalRow, ordinal: Int, buffer: ByteBuffer): Unit = { buffer.putInt(row.getInt(ordinal)) } @@ -134,9 +134,9 @@ private[sql] object INT extends NativeColumnType(IntegerType, 0, 4) { row.setInt(ordinal, value) } - override def getField(row: Row, ordinal: Int): Int = row.getInt(ordinal) + override def getField(row: InternalRow, ordinal: Int): Int = row.getInt(ordinal) - override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { + override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) { to.setInt(toOrdinal, from.getInt(fromOrdinal)) } } @@ -146,7 +146,7 @@ private[sql] object LONG extends NativeColumnType(LongType, 1, 8) { buffer.putLong(v) } - override def append(row: Row, ordinal: Int, buffer: ByteBuffer): Unit = { + override def append(row: InternalRow, ordinal: Int, buffer: ByteBuffer): Unit = { buffer.putLong(row.getLong(ordinal)) } @@ -162,9 +162,9 @@ private[sql] object LONG extends NativeColumnType(LongType, 1, 8) { row.setLong(ordinal, value) } - override def getField(row: Row, ordinal: Int): Long = row.getLong(ordinal) + override def getField(row: InternalRow, ordinal: Int): Long = row.getLong(ordinal) - override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { + override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) { to.setLong(toOrdinal, from.getLong(fromOrdinal)) } } @@ -174,7 +174,7 @@ private[sql] object FLOAT extends NativeColumnType(FloatType, 2, 4) { buffer.putFloat(v) } - override def append(row: Row, ordinal: Int, buffer: ByteBuffer): Unit = { + override def append(row: InternalRow, ordinal: Int, buffer: ByteBuffer): Unit = { buffer.putFloat(row.getFloat(ordinal)) } @@ -190,9 +190,9 @@ private[sql] object FLOAT extends NativeColumnType(FloatType, 2, 4) { row.setFloat(ordinal, value) } - override def getField(row: Row, ordinal: Int): Float = row.getFloat(ordinal) + override def getField(row: InternalRow, ordinal: Int): Float = row.getFloat(ordinal) - override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { + override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) { to.setFloat(toOrdinal, from.getFloat(fromOrdinal)) } } @@ -202,7 +202,7 @@ private[sql] object DOUBLE extends NativeColumnType(DoubleType, 3, 8) { buffer.putDouble(v) } - override def append(row: Row, ordinal: Int, buffer: ByteBuffer): Unit = { + override def append(row: InternalRow, ordinal: Int, buffer: ByteBuffer): Unit = { buffer.putDouble(row.getDouble(ordinal)) } @@ -218,9 +218,9 @@ private[sql] object DOUBLE extends NativeColumnType(DoubleType, 3, 8) { row.setDouble(ordinal, value) } - override def getField(row: Row, ordinal: Int): Double = row.getDouble(ordinal) + override def getField(row: InternalRow, ordinal: Int): Double = row.getDouble(ordinal) - override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { + override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) { to.setDouble(toOrdinal, 
from.getDouble(fromOrdinal)) } } @@ -230,7 +230,7 @@ private[sql] object BOOLEAN extends NativeColumnType(BooleanType, 4, 1) { buffer.put(if (v) 1: Byte else 0: Byte) } - override def append(row: Row, ordinal: Int, buffer: ByteBuffer): Unit = { + override def append(row: InternalRow, ordinal: Int, buffer: ByteBuffer): Unit = { buffer.put(if (row.getBoolean(ordinal)) 1: Byte else 0: Byte) } @@ -244,9 +244,9 @@ private[sql] object BOOLEAN extends NativeColumnType(BooleanType, 4, 1) { row.setBoolean(ordinal, value) } - override def getField(row: Row, ordinal: Int): Boolean = row.getBoolean(ordinal) + override def getField(row: InternalRow, ordinal: Int): Boolean = row.getBoolean(ordinal) - override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { + override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) { to.setBoolean(toOrdinal, from.getBoolean(fromOrdinal)) } } @@ -256,7 +256,7 @@ private[sql] object BYTE extends NativeColumnType(ByteType, 5, 1) { buffer.put(v) } - override def append(row: Row, ordinal: Int, buffer: ByteBuffer): Unit = { + override def append(row: InternalRow, ordinal: Int, buffer: ByteBuffer): Unit = { buffer.put(row.getByte(ordinal)) } @@ -272,9 +272,9 @@ private[sql] object BYTE extends NativeColumnType(ByteType, 5, 1) { row.setByte(ordinal, value) } - override def getField(row: Row, ordinal: Int): Byte = row.getByte(ordinal) + override def getField(row: InternalRow, ordinal: Int): Byte = row.getByte(ordinal) - override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { + override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) { to.setByte(toOrdinal, from.getByte(fromOrdinal)) } } @@ -284,7 +284,7 @@ private[sql] object SHORT extends NativeColumnType(ShortType, 6, 2) { buffer.putShort(v) } - override def append(row: Row, ordinal: Int, buffer: ByteBuffer): Unit = { + override def append(row: InternalRow, ordinal: Int, buffer: ByteBuffer): Unit = { buffer.putShort(row.getShort(ordinal)) } @@ -300,15 +300,15 @@ private[sql] object SHORT extends NativeColumnType(ShortType, 6, 2) { row.setShort(ordinal, value) } - override def getField(row: Row, ordinal: Int): Short = row.getShort(ordinal) + override def getField(row: InternalRow, ordinal: Int): Short = row.getShort(ordinal) - override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { + override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) { to.setShort(toOrdinal, from.getShort(fromOrdinal)) } } private[sql] object STRING extends NativeColumnType(StringType, 7, 8) { - override def actualSize(row: Row, ordinal: Int): Int = { + override def actualSize(row: InternalRow, ordinal: Int): Int = { row.getString(ordinal).getBytes("utf-8").length + 4 } @@ -328,11 +328,11 @@ private[sql] object STRING extends NativeColumnType(StringType, 7, 8) { row.update(ordinal, value) } - override def getField(row: Row, ordinal: Int): UTF8String = { + override def getField(row: InternalRow, ordinal: Int): UTF8String = { row(ordinal).asInstanceOf[UTF8String] } - override def copyField(from: Row, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = { + override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) { to.update(toOrdinal, from(fromOrdinal)) } } @@ -346,7 +346,7 @@ private[sql] object DATE extends NativeColumnType(DateType, 8, 4) { buffer.putInt(v) } - override def getField(row: Row, ordinal: 
Int): Int = { + override def getField(row: InternalRow, ordinal: Int): Int = { row(ordinal).asInstanceOf[Int] } @@ -364,7 +364,7 @@ private[sql] object TIMESTAMP extends NativeColumnType(TimestampType, 9, 8) { buffer.putLong(v) } - override def getField(row: Row, ordinal: Int): Long = { + override def getField(row: InternalRow, ordinal: Int): Long = { row(ordinal).asInstanceOf[Long] } @@ -387,7 +387,7 @@ private[sql] case class FIXED_DECIMAL(precision: Int, scale: Int) buffer.putLong(v.toUnscaledLong) } - override def getField(row: Row, ordinal: Int): Decimal = { + override def getField(row: InternalRow, ordinal: Int): Decimal = { row(ordinal).asInstanceOf[Decimal] } @@ -405,7 +405,7 @@ private[sql] sealed abstract class ByteArrayColumnType[T <: DataType]( defaultSize: Int) extends ColumnType[T, Array[Byte]](typeId, defaultSize) { - override def actualSize(row: Row, ordinal: Int): Int = { + override def actualSize(row: InternalRow, ordinal: Int): Int = { getField(row, ordinal).length + 4 } @@ -426,7 +426,7 @@ private[sql] object BINARY extends ByteArrayColumnType[BinaryType.type](11, 16) row(ordinal) = value } - override def getField(row: Row, ordinal: Int): Array[Byte] = { + override def getField(row: InternalRow, ordinal: Int): Array[Byte] = { row(ordinal).asInstanceOf[Array[Byte]] } } @@ -439,7 +439,7 @@ private[sql] object GENERIC extends ByteArrayColumnType[DataType](12, 16) { row(ordinal) = SparkSqlSerializer.deserialize[Any](value) } - override def getField(row: Row, ordinal: Int): Array[Byte] = { + override def getField(row: InternalRow, ordinal: Int): Array[Byte] = { SparkSqlSerializer.serialize(row(ordinal)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala index 761f427b8cd0d..cb1fd4947fdbc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala @@ -146,7 +146,8 @@ private[sql] case class InMemoryRelation( rowCount += 1 } - val stats = InternalRow.merge(columnBuilders.map(_.columnStats.collectedStatistics) : _*) + val stats = InternalRow.fromSeq(columnBuilders.map(_.columnStats.collectedStatistics) + .flatMap(_.toSeq)) batchStats += stats CachedBatch(columnBuilders.map(_.build().array()), stats) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala index eea15aff5dbcf..b19ad4f1c563e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala @@ -20,22 +20,20 @@ package org.apache.spark.sql.execution import java.nio.ByteBuffer import java.util.{HashMap => JavaHashMap} -import org.apache.spark.sql.types.Decimal - import scala.reflect.ClassTag import com.clearspring.analytics.stream.cardinality.HyperLogLog import com.esotericsoftware.kryo.io.{Input, Output} -import com.esotericsoftware.kryo.{Serializer, Kryo} +import com.esotericsoftware.kryo.{Kryo, Serializer} import com.twitter.chill.ResourcePool -import org.apache.spark.{SparkEnv, SparkConf} -import org.apache.spark.serializer.{SerializerInstance, KryoSerializer} -import org.apache.spark.sql.catalyst.expressions.GenericRow -import org.apache.spark.util.collection.OpenHashSet -import org.apache.spark.util.MutablePair - 
+import org.apache.spark.serializer.{KryoSerializer, SerializerInstance} +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{IntegerHashSet, LongHashSet} +import org.apache.spark.sql.types.Decimal +import org.apache.spark.util.MutablePair +import org.apache.spark.util.collection.OpenHashSet +import org.apache.spark.{SparkConf, SparkEnv} private[sql] class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(conf) { override def newKryo(): Kryo = { @@ -43,6 +41,7 @@ private[sql] class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(co kryo.setRegistrationRequired(false) kryo.register(classOf[MutablePair[_, _]]) kryo.register(classOf[org.apache.spark.sql.catalyst.expressions.GenericRow]) + kryo.register(classOf[org.apache.spark.sql.catalyst.expressions.GenericInternalRow]) kryo.register(classOf[org.apache.spark.sql.catalyst.expressions.GenericMutableRow]) kryo.register(classOf[com.clearspring.analytics.stream.cardinality.HyperLogLog], new HyperLogLogSerializer) @@ -139,7 +138,7 @@ private[sql] class OpenHashSetSerializer extends Serializer[OpenHashSet[_]] { val iterator = hs.iterator while(iterator.hasNext) { val row = iterator.next() - rowSerializer.write(kryo, output, row.asInstanceOf[GenericRow].values) + rowSerializer.write(kryo, output, row.asInstanceOf[GenericInternalRow].values) } } @@ -150,7 +149,7 @@ private[sql] class OpenHashSetSerializer extends Serializer[OpenHashSet[_]] { var i = 0 while (i < numItems) { val row = - new GenericRow(rowSerializer.read( + new GenericInternalRow(rowSerializer.read( kryo, input, classOf[Array[Any]].asInstanceOf[Class[Any]]).asInstanceOf[Array[Any]]) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala index 15b6936acd59b..74a22353b1d27 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala @@ -26,7 +26,8 @@ import scala.reflect.ClassTag import org.apache.spark.Logging import org.apache.spark.serializer._ import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, MutableRow, SpecificMutableRow} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{MutableRow, SpecificMutableRow} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -329,7 +330,7 @@ private[sql] object SparkSqlSerializer2 { */ def createDeserializationFunction( schema: Array[DataType], - in: DataInputStream): (MutableRow) => Row = { + in: DataInputStream): (MutableRow) => InternalRow = { if (schema == null) { (mutableRow: MutableRow) => null } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 21912cf24933e..5daf86d817586 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -210,8 +210,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } } - protected lazy val singleRowRdd = - sparkContext.parallelize(Seq(new GenericRow(Array[Any]()): InternalRow), 1) + protected lazy val singleRowRdd = sparkContext.parallelize(Seq(InternalRow()), 1) object 
TakeOrderedAndProject extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala index bce0e8d70a57b..e41538ec1fc1a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala @@ -71,8 +71,8 @@ case class HashOuterJoin( @transient private[this] lazy val DUMMY_LIST = Seq[InternalRow](null) @transient private[this] lazy val EMPTY_LIST = Seq.empty[InternalRow] - @transient private[this] lazy val leftNullRow = new GenericRow(left.output.length) - @transient private[this] lazy val rightNullRow = new GenericRow(right.output.length) + @transient private[this] lazy val leftNullRow = new GenericInternalRow(left.output.length) + @transient private[this] lazy val rightNullRow = new GenericInternalRow(right.output.length) @transient private[this] lazy val boundCondition = condition.map( newPredicate(_, left.output ++ right.output)).getOrElse((row: InternalRow) => true) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala index f9c3fe92c2670..036f5d253e385 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala @@ -183,9 +183,9 @@ object EvaluatePython { }.toMap case (c, StructType(fields)) if c.getClass.isArray => - new GenericRow(c.asInstanceOf[Array[_]].zip(fields).map { + new GenericInternalRow(c.asInstanceOf[Array[_]].zip(fields).map { case (e, f) => fromJava(e, f.dataType) - }): Row + }) case (c: java.util.Calendar, DateType) => DateTimeUtils.fromJavaDate(new java.sql.Date(c.getTimeInMillis)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala index 252c611d02ebc..042e2c9cbb22e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, Cast} import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String private[sql] object StatFunctions extends Logging { @@ -123,7 +124,7 @@ private[sql] object StatFunctions extends Logging { countsRow.setLong(distinctCol2.get(row.get(1)).get + 1, row.getLong(2)) } // the value of col1 is the first value, the rest are the counts - countsRow.setString(0, col1Item.toString) + countsRow.update(0, UTF8String.fromString(col1Item.toString)) countsRow }.toSeq val headerNames = distinctCol2.map(r => StructField(r._1.toString, LongType)).toSeq diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala index 8b4276b2c364c..30c5f4ca3e1b2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala @@ -417,7 +417,7 @@ private[sql] class JDBCRDD( case IntegerConversion => mutableRow.setInt(i, rs.getInt(pos)) case LongConversion => 
mutableRow.setLong(i, rs.getLong(pos)) // TODO(davies): use getBytes for better performance, if the encoding is UTF-8 - case StringConversion => mutableRow.setString(i, rs.getString(pos)) + case StringConversion => mutableRow.update(i, UTF8String.fromString(rs.getString(pos))) case TimestampConversion => val t = rs.getTimestamp(pos) if (t != null) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala index cf7aa44e4cd55..ae7cbf0624dc8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala @@ -318,7 +318,7 @@ private[parquet] class CatalystGroupConverter( // Note: this will ever only be called in the root converter when the record has been // fully processed. Therefore it will be difficult to use mutable rows instead, since // any non-root converter never would be sure when it would be safe to re-use the buffer. - new GenericRow(current.toArray) + new GenericInternalRow(current.toArray) } override def getConverter(fieldIndex: Int): Converter = converters(fieldIndex) @@ -342,8 +342,8 @@ private[parquet] class CatalystGroupConverter( override def end(): Unit = { if (!isRootConverter) { assert(current != null) // there should be no empty groups - buffer.append(new GenericRow(current.toArray)) - parent.updateField(index, new GenericRow(buffer.toArray.asInstanceOf[Array[Any]])) + buffer.append(new GenericInternalRow(current.toArray)) + parent.updateField(index, new GenericInternalRow(buffer.toArray.asInstanceOf[Array[Any]])) } } } @@ -788,7 +788,7 @@ private[parquet] class CatalystStructConverter( // here we need to make sure to use StructScalaType // Note: we need to actually make a copy of the array since we // may be in a nested field - parent.updateField(index, new GenericRow(current.toArray)) + parent.updateField(index, new GenericInternalRow(current.toArray)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala index dbb369cf45502..54c8eeb41a8ea 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala @@ -44,7 +44,7 @@ private[sql] case class InsertIntoDataSource( overwrite: Boolean) extends RunnableCommand { - override def run(sqlContext: SQLContext): Seq[InternalRow] = { + override def run(sqlContext: SQLContext): Seq[Row] = { val relation = logicalRelation.relation.asInstanceOf[InsertableRelation] val data = DataFrame(sqlContext, query) // Apply the schema of the existing table to the new data. @@ -54,7 +54,7 @@ private[sql] case class InsertIntoDataSource( // Invalidate the cache. 
sqlContext.cacheManager.invalidateCache(logicalRelation) - Seq.empty[InternalRow] + Seq.empty[Row] } } @@ -86,7 +86,7 @@ private[sql] case class InsertIntoHadoopFsRelation( mode: SaveMode) extends RunnableCommand { - override def run(sqlContext: SQLContext): Seq[InternalRow] = { + override def run(sqlContext: SQLContext): Seq[Row] = { require( relation.paths.length == 1, s"Cannot write to multiple destinations: ${relation.paths.mkString(",")}") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala index ece3d6fdf2af5..4cb5ba2f0d5eb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql import java.sql.{Date, Timestamp} import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.expressions._ case class ReflectData( stringField: String, @@ -128,16 +127,16 @@ class ScalaReflectionRelationSuite extends SparkFunSuite { Seq(data).toDF().registerTempTable("reflectComplexData") assert(ctx.sql("SELECT * FROM reflectComplexData").collect().head === - new GenericRow(Array[Any]( + Row( Seq(1, 2, 3), Seq(1, 2, null), Map(1 -> 10L, 2 -> 20L), Map(1 -> 10L, 2 -> 20L, 3 -> null), - new GenericRow(Array[Any]( + Row( Seq(10, 20, 30), Seq(10, 20, null), Map(10 -> 100L, 20 -> 200L), Map(10 -> 100L, 20 -> 200L, 30 -> null), - new GenericRow(Array[Any](null, "abc"))))))) + Row(null, "abc")))) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala index 5fc53f7012994..54e1efb6e36e7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala @@ -62,7 +62,7 @@ case class SimpleDDLScan(from: Int, to: Int, table: String)(@transient val sqlCo override def buildScan(): RDD[Row] = { sqlContext.sparkContext.parallelize(from to to).map { e => - InternalRow(UTF8String.fromString(s"people$e"), e * 2) + InternalRow(UTF8String.fromString(s"people$e"), e * 2): Row } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala index de0ed0c0427a6..2c916f3322b6d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala @@ -90,8 +90,8 @@ case class AllDataTypesScan( Seq(Map(UTF8String.fromString(s"str_$i") -> InternalRow(i.toLong))), Map(i -> UTF8String.fromString(i.toString)), Map(Map(UTF8String.fromString(s"str_$i") -> i.toFloat) -> InternalRow(i.toLong)), - Row(i, UTF8String.fromString(i.toString)), - Row(Seq(UTF8String.fromString(s"str_$i"), UTF8String.fromString(s"str_${i + 1}")), + InternalRow(i, UTF8String.fromString(i.toString)), + InternalRow(Seq(UTF8String.fromString(s"str_$i"), UTF8String.fromString(s"str_${i + 1}")), InternalRow(Seq(DateTimeUtils.fromJavaDate(new Date(1970, 1, i + 1)))))) } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index 864c888ab073d..a6b8ead577fb5 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ 
b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -336,9 +336,8 @@ private[hive] trait HiveInspectors { // currently, hive doesn't provide the ConstantStructObjectInspector case si: StructObjectInspector => val allRefs = si.getAllStructFieldRefs - new GenericRow( - allRefs.map(r => - unwrap(si.getStructFieldData(data, r), r.getFieldObjectInspector)).toArray) + InternalRow.fromSeq( + allRefs.map(r => unwrap(si.getStructFieldData(data, r), r.getFieldObjectInspector))) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala index 00e61e35d4354..b251a9523bed6 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala @@ -34,6 +34,7 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.{EmptyRDD, HadoopRDD, RDD, UnionRDD} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.{SerializableConfiguration, Utils} /** @@ -356,7 +357,7 @@ private[hive] object HadoopTableReader extends HiveInspectors with Logging { (value: Any, row: MutableRow, ordinal: Int) => row.setDouble(ordinal, oi.get(value)) case oi: HiveVarcharObjectInspector => (value: Any, row: MutableRow, ordinal: Int) => - row.setString(ordinal, oi.getPrimitiveJavaObject(value).getValue) + row.update(ordinal, UTF8String.fromString(oi.getPrimitiveJavaObject(value).getValue)) case oi: HiveDecimalObjectInspector => (value: Any, row: MutableRow, ordinal: Int) => row.update(ordinal, HiveShim.toCatalystDecimal(oi, value)) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala index 0e4a2427a9c15..84358cb73c9e3 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala @@ -17,13 +17,11 @@ package org.apache.spark.sql.hive.execution -import org.apache.spark.annotation.Experimental -import org.apache.spark.sql.{AnalysisException, SQLContext} -import org.apache.spark.sql.catalyst.expressions.InternalRow import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, LogicalPlan} import org.apache.spark.sql.execution.RunnableCommand -import org.apache.spark.sql.hive.client.{HiveTable, HiveColumn} -import org.apache.spark.sql.hive.{HiveContext, MetastoreRelation, HiveMetastoreTypes} +import org.apache.spark.sql.hive.client.{HiveColumn, HiveTable} +import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes, MetastoreRelation} +import org.apache.spark.sql.{AnalysisException, Row, SQLContext} /** * Create table and insert the query result into it. 
@@ -42,11 +40,11 @@ case class CreateTableAsSelect( def database: String = tableDesc.database def tableName: String = tableDesc.name - override def run(sqlContext: SQLContext): Seq[InternalRow] = { + override def run(sqlContext: SQLContext): Seq[Row] = { val hiveContext = sqlContext.asInstanceOf[HiveContext] lazy val metastoreRelation: MetastoreRelation = { - import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe import org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat + import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe import org.apache.hadoop.io.Text import org.apache.hadoop.mapred.TextInputFormat @@ -89,7 +87,7 @@ case class CreateTableAsSelect( hiveContext.executePlan(InsertIntoTable(metastoreRelation, Map(), query, true, false)).toRdd } - Seq.empty[InternalRow] + Seq.empty[Row] } override def argString: String = { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DescribeHiveTableCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DescribeHiveTableCommand.scala index a89381000ad5f..5f0ed5393d191 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DescribeHiveTableCommand.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DescribeHiveTableCommand.scala @@ -21,10 +21,10 @@ import scala.collection.JavaConversions._ import org.apache.hadoop.hive.metastore.api.FieldSchema -import org.apache.spark.sql.SQLContext -import org.apache.spark.sql.catalyst.expressions.{Attribute, InternalRow} +import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.execution.RunnableCommand import org.apache.spark.sql.hive.MetastoreRelation +import org.apache.spark.sql.{Row, SQLContext} /** * Implementation for "describe [extended] table". @@ -35,7 +35,7 @@ case class DescribeHiveTableCommand( override val output: Seq[Attribute], isExtended: Boolean) extends RunnableCommand { - override def run(sqlContext: SQLContext): Seq[InternalRow] = { + override def run(sqlContext: SQLContext): Seq[Row] = { // Trying to mimic the format of Hive's output. But not exactly the same. 
var results: Seq[(String, String, String)] = Nil @@ -57,7 +57,7 @@ case class DescribeHiveTableCommand( } results.map { case (name, dataType, comment) => - InternalRow(name, dataType, comment) + Row(name, dataType, comment) } } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveNativeCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveNativeCommand.scala index 87f8e3f7fcfcc..41b645b2c9c93 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveNativeCommand.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveNativeCommand.scala @@ -17,11 +17,11 @@ package org.apache.spark.sql.hive.execution -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, InternalRow} +import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.execution.RunnableCommand import org.apache.spark.sql.hive.HiveContext -import org.apache.spark.sql.SQLContext import org.apache.spark.sql.types.StringType +import org.apache.spark.sql.{Row, SQLContext} private[hive] case class HiveNativeCommand(sql: String) extends RunnableCommand { @@ -29,6 +29,6 @@ case class HiveNativeCommand(sql: String) extends RunnableCommand { override def output: Seq[AttributeReference] = Seq(AttributeReference("result", StringType, nullable = false)()) - override def run(sqlContext: SQLContext): Seq[InternalRow] = - sqlContext.asInstanceOf[HiveContext].runSqlHive(sql).map(InternalRow(_)) + override def run(sqlContext: SQLContext): Seq[Row] = + sqlContext.asInstanceOf[HiveContext].runSqlHive(sql).map(Row(_)) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala index 1f5e4af2e4746..f4c8c9a7e8a68 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala @@ -123,7 +123,7 @@ case class HiveTableScan( // Only partitioned values are needed here, since the predicate has already been bound to // partition key attribute references. 
- val row = new GenericRow(castedValues.toArray) + val row = InternalRow.fromSeq(castedValues) shouldKeep.eval(row).asInstanceOf[Boolean] } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala index 9d8872aa47d1f..611888055d6cf 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala @@ -129,11 +129,11 @@ case class ScriptTransformation( val prevLine = curLine curLine = reader.readLine() if (!ioschema.schemaLess) { - new GenericRow(CatalystTypeConverters.convertToCatalyst( + new GenericInternalRow(CatalystTypeConverters.convertToCatalyst( prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD"))) .asInstanceOf[Array[Any]]) } else { - new GenericRow(CatalystTypeConverters.convertToCatalyst( + new GenericInternalRow(CatalystTypeConverters.convertToCatalyst( prevLine.split(ioschema.outputRowFormatMap("TOK_TABLEROWFORMATFIELD"), 2)) .asInstanceOf[Array[Any]]) } @@ -167,7 +167,8 @@ case class ScriptTransformation( outputStream.write(data) } else { - val writable = inputSerde.serialize(row.asInstanceOf[GenericRow].values, inputSoi) + val writable = inputSerde.serialize( + row.asInstanceOf[GenericInternalRow].values, inputSoi) prepareWritable(writable).write(dataOutputStream) } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala index aad58bfa2e6e0..71fa3e9c33ad9 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala @@ -17,15 +17,14 @@ package org.apache.spark.sql.hive.execution -import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries -import org.apache.spark.sql.catalyst.util._ -import org.apache.spark.sql.sources._ -import org.apache.spark.sql.{SaveMode, DataFrame, SQLContext} -import org.apache.spark.sql.catalyst.expressions.{Attribute, InternalRow} +import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.execution.RunnableCommand import org.apache.spark.sql.hive.HiveContext +import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -39,9 +38,9 @@ import org.apache.spark.util.Utils private[hive] case class AnalyzeTable(tableName: String) extends RunnableCommand { - override def run(sqlContext: SQLContext): Seq[InternalRow] = { + override def run(sqlContext: SQLContext): Seq[Row] = { sqlContext.asInstanceOf[HiveContext].analyze(tableName) - Seq.empty[InternalRow] + Seq.empty[Row] } } @@ -53,7 +52,7 @@ case class DropTable( tableName: String, ifExists: Boolean) extends RunnableCommand { - override def run(sqlContext: SQLContext): Seq[InternalRow] = { + override def run(sqlContext: SQLContext): Seq[Row] = { val hiveContext = sqlContext.asInstanceOf[HiveContext] val ifExistsClause = if (ifExists) "IF EXISTS " else "" try { @@ -70,7 +69,7 @@ case class DropTable( hiveContext.invalidateTable(tableName) hiveContext.runSqlHive(s"DROP TABLE $ifExistsClause$tableName") hiveContext.catalog.unregisterTable(Seq(tableName)) - 
Seq.empty[InternalRow] + Seq.empty[Row] } } @@ -83,7 +82,7 @@ case class AddJar(path: String) extends RunnableCommand { schema.toAttributes } - override def run(sqlContext: SQLContext): Seq[InternalRow] = { + override def run(sqlContext: SQLContext): Seq[Row] = { val hiveContext = sqlContext.asInstanceOf[HiveContext] val currentClassLoader = Utils.getContextOrSparkClassLoader @@ -105,18 +104,18 @@ case class AddJar(path: String) extends RunnableCommand { // Add jar to executors hiveContext.sparkContext.addJar(path) - Seq(InternalRow(0)) + Seq(Row(0)) } } private[hive] case class AddFile(path: String) extends RunnableCommand { - override def run(sqlContext: SQLContext): Seq[InternalRow] = { + override def run(sqlContext: SQLContext): Seq[Row] = { val hiveContext = sqlContext.asInstanceOf[HiveContext] hiveContext.runSqlHive(s"ADD FILE $path") hiveContext.sparkContext.addFile(path) - Seq.empty[InternalRow] + Seq.empty[Row] } } @@ -129,12 +128,12 @@ case class CreateMetastoreDataSource( allowExisting: Boolean, managedIfNoPath: Boolean) extends RunnableCommand { - override def run(sqlContext: SQLContext): Seq[InternalRow] = { + override def run(sqlContext: SQLContext): Seq[Row] = { val hiveContext = sqlContext.asInstanceOf[HiveContext] if (hiveContext.catalog.tableExists(tableName :: Nil)) { if (allowExisting) { - return Seq.empty[InternalRow] + return Seq.empty[Row] } else { throw new AnalysisException(s"Table $tableName already exists.") } @@ -157,7 +156,7 @@ case class CreateMetastoreDataSource( optionsWithPath, isExternal) - Seq.empty[InternalRow] + Seq.empty[Row] } } @@ -170,7 +169,7 @@ case class CreateMetastoreDataSourceAsSelect( options: Map[String, String], query: LogicalPlan) extends RunnableCommand { - override def run(sqlContext: SQLContext): Seq[InternalRow] = { + override def run(sqlContext: SQLContext): Seq[Row] = { val hiveContext = sqlContext.asInstanceOf[HiveContext] var createMetastoreTable = false var isExternal = true @@ -194,7 +193,7 @@ case class CreateMetastoreDataSourceAsSelect( s"Or, if you are using SQL CREATE TABLE, you need to drop $tableName first.") case SaveMode.Ignore => // Since the table already exists and the save mode is Ignore, we will just return. - return Seq.empty[InternalRow] + return Seq.empty[Row] case SaveMode.Append => // Check if the specified data source match the data source of the existing table. val resolved = ResolvedDataSource( @@ -259,6 +258,6 @@ case class CreateMetastoreDataSourceAsSelect( // Refresh the cache of the table in the catalog. hiveContext.refreshTable(tableName) - Seq.empty[InternalRow] + Seq.empty[Row] } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala index 0fd7b3a91d6dd..300f83d914ea4 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala @@ -190,7 +190,7 @@ private[sql] class OrcRelation( filters: Array[Filter], inputPaths: Array[FileStatus]): RDD[Row] = { val output = StructType(requiredColumns.map(dataSchema(_))).toAttributes - OrcTableScan(output, this, filters, inputPaths).execute() + OrcTableScan(output, this, filters, inputPaths).execute().map(_.asInstanceOf[Row]) } override def prepareJobForWrite(job: Job): OutputWriterFactory = { @@ -234,13 +234,13 @@ private[orc] case class OrcTableScan( HiveShim.appendReadColumns(conf, sortedIds, sortedNames) } - // Transform all given raw `Writable`s into `Row`s. 
+ // Transform all given raw `Writable`s into `InternalRow`s. private def fillObject( path: String, conf: Configuration, iterator: Iterator[Writable], nonPartitionKeyAttrs: Seq[(Attribute, Int)], - mutableRow: MutableRow): Iterator[Row] = { + mutableRow: MutableRow): Iterator[InternalRow] = { val deserializer = new OrcSerde val soi = OrcFileOperator.getObjectInspector(path, Some(conf)) val (fieldRefs, fieldOrdinals) = nonPartitionKeyAttrs.map { @@ -261,11 +261,11 @@ private[orc] case class OrcTableScan( } i += 1 } - mutableRow: Row + mutableRow: InternalRow } } - def execute(): RDD[Row] = { + def execute(): RDD[InternalRow] = { val job = new Job(sqlContext.sparkContext.hadoopConfiguration) val conf = job.getConfiguration diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala index aff0456b37ed5..a93acb938d5fa 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala @@ -202,9 +202,9 @@ class HiveInspectorSuite extends SparkFunSuite with HiveInspectors { val dt = StructType(dataTypes.zipWithIndex.map { case (t, idx) => StructField(s"c_$idx", t) }) - + val inspector = toInspector(dt) checkValues(row, - unwrap(wrap(Row.fromSeq(row), toInspector(dt)), toInspector(dt)).asInstanceOf[InternalRow]) + unwrap(wrap(InternalRow.fromSeq(row), inspector), inspector).asInstanceOf[InternalRow]) checkValue(null, unwrap(wrap(null, toInspector(dt)), toInspector(dt))) } From ec784381967506f8db4d6a357c0b72df25a0aa1b Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Sun, 28 Jun 2015 08:29:07 -0700 Subject: [PATCH 032/122] [SPARK-8686] [SQL] DataFrame should support `where` with expression represented by String DataFrame supports `filter` function with two types of argument, `Column` and `String`. But `where` doesn't. Author: Kousuke Saruta Closes #7063 from sarutak/SPARK-8686 and squashes the following commits: 180f9a4 [Kousuke Saruta] Added test d61aec4 [Kousuke Saruta] Add "where" method with String argument to DataFrame --- .../main/scala/org/apache/spark/sql/DataFrame.scala | 12 ++++++++++++ .../scala/org/apache/spark/sql/DataFrameSuite.scala | 6 ++++++ 2 files changed, 18 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 0db4df34f9e22..d75d88307562e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -714,6 +714,18 @@ class DataFrame private[sql]( */ def where(condition: Column): DataFrame = filter(condition) + /** + * Filters rows using the given SQL expression. + * {{{ + * peopleDf.where("age > 15") + * }}} + * @group dfops + * @since 1.5.0 + */ + def where(conditionExpr: String): DataFrame = { + filter(Column(new SqlParser().parseExpression(conditionExpr))) + } + /** * Groups the [[DataFrame]] using the specified columns, so we can run aggregation on them. * See [[GroupedData]] for all the available aggregate functions. 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 47443a917b765..d06b9c5785527 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -160,6 +160,12 @@ class DataFrameSuite extends QueryTest { testData.collect().filter(_.getInt(0) > 90).toSeq) } + test("filterExpr using where") { + checkAnswer( + testData.where("key > 50"), + testData.collect().filter(_.getInt(0) > 50).toSeq) + } + test("repartition") { checkAnswer( testData.select('key).repartition(10).select('key), From 9ce78b4343febe87c4edd650c698cc20d38f615d Mon Sep 17 00:00:00 2001 From: "Vincent D. Warmerdam" Date: Sun, 28 Jun 2015 13:33:33 -0700 Subject: [PATCH 033/122] [SPARK-8596] [EC2] Added port for Rstudio This would otherwise need to be set manually by R users in AWS. https://issues.apache.org/jira/browse/SPARK-8596 Author: Vincent D. Warmerdam Author: vincent Closes #7068 from koaning/rstudio-port-number and squashes the following commits: ac8100d [vincent] Update spark_ec2.py ce6ad88 [Vincent D. Warmerdam] added port number for rstudio --- ec2/spark_ec2.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py index e4932cfa7a4fc..18ccbc0a3edd0 100755 --- a/ec2/spark_ec2.py +++ b/ec2/spark_ec2.py @@ -505,6 +505,8 @@ def launch_cluster(conn, opts, cluster_name): master_group.authorize('tcp', 50070, 50070, authorized_address) master_group.authorize('tcp', 60070, 60070, authorized_address) master_group.authorize('tcp', 4040, 4045, authorized_address) + # Rstudio (GUI for R) needs port 8787 for web access + master_group.authorize('tcp', 8787, 8787, authorized_address) # HDFS NFS gateway requires 111,2049,4242 for tcp & udp master_group.authorize('tcp', 111, 111, authorized_address) master_group.authorize('udp', 111, 111, authorized_address) From 24fda7381171738cbbbacb5965393b660763e562 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sun, 28 Jun 2015 14:48:44 -0700 Subject: [PATCH 034/122] [SPARK-8677] [SQL] Fix non-terminating decimal expansion for decimal divide operation JIRA: https://issues.apache.org/jira/browse/SPARK-8677 Author: Liang-Chi Hsieh Closes #7056 from viirya/fix_decimal3 and squashes the following commits: 34d7419 [Liang-Chi Hsieh] Fix Non-terminating decimal expansion for decimal divide operation. --- .../scala/org/apache/spark/sql/types/Decimal.scala | 11 +++++++++-- .../apache/spark/sql/types/decimal/DecimalSuite.scala | 5 +++++ 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index bd9823bc05424..5a169488c97eb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -265,8 +265,15 @@ final class Decimal extends Ordered[Decimal] with Serializable { def * (that: Decimal): Decimal = Decimal(toBigDecimal * that.toBigDecimal) - def / (that: Decimal): Decimal = - if (that.isZero) null else Decimal(toBigDecimal / that.toBigDecimal) + def / (that: Decimal): Decimal = { + if (that.isZero) { + null + } else { + // To avoid non-terminating decimal expansion problem, we turn to Java BigDecimal's divide + // with specified ROUNDING_MODE. 
+ Decimal(toJavaBigDecimal.divide(that.toJavaBigDecimal, ROUNDING_MODE.id)) + } + } def % (that: Decimal): Decimal = if (that.isZero) null else Decimal(toBigDecimal % that.toBigDecimal) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala index ccc29c0dc8c35..5f312964e5bf7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala @@ -167,4 +167,9 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester { val decimal = (Decimal(Long.MaxValue, 38, 0) * Decimal(Long.MaxValue, 38, 0)).toJavaBigDecimal assert(decimal.unscaledValue.toString === "85070591730234615847396907784232501249") } + + test("fix non-terminating decimal expansion problem") { + val decimal = Decimal(1.0, 10, 3) / Decimal(3.0, 10, 3) + assert(decimal.toString === "0.333") + } } From 00a9d22bd6ef42c1e7d8dd936798b449bb3a9f67 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Sun, 28 Jun 2015 19:34:59 -0700 Subject: [PATCH 035/122] [SPARK-7845] [BUILD] Bumping default Hadoop version used in profile hadoop-1 to 1.2.1 PR #5694 reverted PR #6384 while refactoring `dev/run-tests` to `dev/run-tests.py`. Also, PR #6384 didn't bump Hadoop 1 version defined in POM. Author: Cheng Lian Closes #7062 from liancheng/spark-7845 and squashes the following commits: c088b72 [Cheng Lian] Bumping default Hadoop version used in profile hadoop-1 to 1.2.1 --- dev/run-tests.py | 2 +- pom.xml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/dev/run-tests.py b/dev/run-tests.py index 3533e0c857b9b..eb79a2a502707 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -261,7 +261,7 @@ def get_hadoop_profiles(hadoop_version): """ sbt_maven_hadoop_profiles = { - "hadoop1.0": ["-Phadoop-1", "-Dhadoop.version=1.0.4"], + "hadoop1.0": ["-Phadoop-1", "-Dhadoop.version=1.2.1"], "hadoop2.0": ["-Phadoop-1", "-Dhadoop.version=2.0.0-mr1-cdh4.1.1"], "hadoop2.2": ["-Pyarn", "-Phadoop-2.2"], "hadoop2.3": ["-Pyarn", "-Phadoop-2.3", "-Dhadoop.version=2.3.0"], diff --git a/pom.xml b/pom.xml index 00f50166b39b6..4c18bd5e42c87 100644 --- a/pom.xml +++ b/pom.xml @@ -1686,7 +1686,7 @@ hadoop-1 - 1.0.4 + 1.2.1 2.4.1 0.98.7-hadoop1 hadoop1 From 25f574eb9a3cb9b93b7d9194a8ec16e00ce2c036 Mon Sep 17 00:00:00 2001 From: Feynman Liang Date: Sun, 28 Jun 2015 22:26:07 -0700 Subject: [PATCH 036/122] [SPARK-7212] [MLLIB] Add sequence learning flag Support mining of ordered frequent item sequences. 
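For illustration, a minimal sketch of driving the new flag from the Scala API (assuming an existing SparkContext `sc`; the transactions are made up):

```scala
import org.apache.spark.mllib.fpm.FPGrowth
import org.apache.spark.rdd.RDD

// Each transaction is an ordered sequence of items.
val transactions: RDD[Array[String]] = sc.parallelize(Seq(
  "a b c",
  "a b c d",
  "a c d").map(_.split(" ")))

val model = new FPGrowth()
  .setMinSupport(0.5)
  .setNumPartitions(2)
  .setOrdered(true) // mine ordered frequent sequences instead of unordered itemsets
  .run(transactions)

model.freqItemsets.collect().foreach { itemset =>
  println(itemset.items.mkString("[", ",", "]") + ", " + itemset.freq)
}
```

With `setOrdered(false)`, the default, behaviour is unchanged, and the auxiliary `FreqItemset` constructor keeps existing callers source-compatible.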
Author: Feynman Liang Closes #6997 from feynmanliang/fp-sequence and squashes the following commits: 7c14e15 [Feynman Liang] Improve scalatests with R code and Seq 0d3e4b6 [Feynman Liang] Fix python test ce987cb [Feynman Liang] Backwards compatibility aux constructor 34ef8f2 [Feynman Liang] Fix failing test due to reverse orderering f04bd50 [Feynman Liang] Naming, add ordered to FreqItemsets, test ordering using Seq 648d4d4 [Feynman Liang] Test case for frequent item sequences 252a36a [Feynman Liang] Add sequence learning flag --- .../org/apache/spark/mllib/fpm/FPGrowth.scala | 38 +++++++++++--- .../spark/mllib/fpm/FPGrowthSuite.scala | 52 ++++++++++++++++++- python/pyspark/mllib/fpm.py | 4 +- 3 files changed, 82 insertions(+), 12 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala index efa8459d3cdba..abac08022ea47 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala @@ -36,7 +36,7 @@ import org.apache.spark.storage.StorageLevel * :: Experimental :: * * Model trained by [[FPGrowth]], which holds frequent itemsets. - * @param freqItemsets frequent itemset, which is an RDD of [[FreqItemset]] + * @param freqItemsets frequent itemsets, which is an RDD of [[FreqItemset]] * @tparam Item item type */ @Experimental @@ -62,13 +62,14 @@ class FPGrowthModel[Item: ClassTag](val freqItemsets: RDD[FreqItemset[Item]]) ex @Experimental class FPGrowth private ( private var minSupport: Double, - private var numPartitions: Int) extends Logging with Serializable { + private var numPartitions: Int, + private var ordered: Boolean) extends Logging with Serializable { /** * Constructs a default instance with default parameters {minSupport: `0.3`, numPartitions: same - * as the input data}. + * as the input data, ordered: `false`}. */ - def this() = this(0.3, -1) + def this() = this(0.3, -1, false) /** * Sets the minimal support level (default: `0.3`). @@ -86,6 +87,15 @@ class FPGrowth private ( this } + /** + * Indicates whether to mine itemsets (unordered) or sequences (ordered) (default: false, mine + * itemsets). + */ + def setOrdered(ordered: Boolean): this.type = { + this.ordered = ordered + this + } + /** * Computes an FP-Growth model that contains frequent itemsets. * @param data input data set, each element contains a transaction @@ -155,7 +165,7 @@ class FPGrowth private ( .flatMap { case (part, tree) => tree.extract(minCount, x => partitioner.getPartition(x) == part) }.map { case (ranks, count) => - new FreqItemset(ranks.map(i => freqItems(i)).toArray, count) + new FreqItemset(ranks.map(i => freqItems(i)).reverse.toArray, count, ordered) } } @@ -171,9 +181,12 @@ class FPGrowth private ( itemToRank: Map[Item, Int], partitioner: Partitioner): mutable.Map[Int, Array[Int]] = { val output = mutable.Map.empty[Int, Array[Int]] - // Filter the basket by frequent items pattern and sort their ranks. + // Filter the basket by frequent items pattern val filtered = transaction.flatMap(itemToRank.get) - ju.Arrays.sort(filtered) + if (!this.ordered) { + ju.Arrays.sort(filtered) + } + // Generate conditional transactions val n = filtered.length var i = n - 1 while (i >= 0) { @@ -198,9 +211,18 @@ object FPGrowth { * Frequent itemset. * @param items items in this itemset. Java users should call [[FreqItemset#javaItems]] instead. 
* @param freq frequency + * @param ordered indicates if items represents an itemset (false) or sequence (true) * @tparam Item item type */ - class FreqItemset[Item](val items: Array[Item], val freq: Long) extends Serializable { + class FreqItemset[Item](val items: Array[Item], val freq: Long, val ordered: Boolean) + extends Serializable { + + /** + * Auxillary constructor, assumes unordered by default. + */ + def this(items: Array[Item], freq: Long) { + this(items, freq, false) + } /** * Returns items in a Java List. diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala index 66ae3543ecc4e..1a8a1e79f2810 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala @@ -22,7 +22,7 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext { - test("FP-Growth using String type") { + test("FP-Growth frequent itemsets using String type") { val transactions = Seq( "r z h k p", "z y x w v u t s", @@ -38,12 +38,14 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext { val model6 = fpg .setMinSupport(0.9) .setNumPartitions(1) + .setOrdered(false) .run(rdd) assert(model6.freqItemsets.count() === 0) val model3 = fpg .setMinSupport(0.5) .setNumPartitions(2) + .setOrdered(false) .run(rdd) val freqItemsets3 = model3.freqItemsets.collect().map { itemset => (itemset.items.toSet, itemset.freq) @@ -61,17 +63,59 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext { val model2 = fpg .setMinSupport(0.3) .setNumPartitions(4) + .setOrdered(false) .run(rdd) assert(model2.freqItemsets.count() === 54) val model1 = fpg .setMinSupport(0.1) .setNumPartitions(8) + .setOrdered(false) .run(rdd) assert(model1.freqItemsets.count() === 625) } - test("FP-Growth using Int type") { + test("FP-Growth frequent sequences using String type"){ + val transactions = Seq( + "r z h k p", + "z y x w v u t s", + "s x o n r", + "x z y m t s q e", + "z", + "x z y r q t p") + .map(_.split(" ")) + val rdd = sc.parallelize(transactions, 2).cache() + + val fpg = new FPGrowth() + + val model1 = fpg + .setMinSupport(0.5) + .setNumPartitions(2) + .setOrdered(true) + .run(rdd) + + /* + Use the following R code to verify association rules using arulesSequences package. 
+ + data = read_baskets("path", info = c("sequenceID","eventID","SIZE")) + freqItemSeq = cspade(data, parameter = list(support = 0.5)) + resSeq = as(freqItemSeq, "data.frame") + resSeq$support = resSeq$support * length(transactions) + names(resSeq)[names(resSeq) == "support"] = "freq" + resSeq + */ + val expected = Set( + (Seq("r"), 3L), (Seq("s"), 3L), (Seq("t"), 3L), (Seq("x"), 4L), (Seq("y"), 3L), + (Seq("z"), 5L), (Seq("z", "y"), 3L), (Seq("x", "t"), 3L), (Seq("y", "t"), 3L), + (Seq("z", "t"), 3L), (Seq("z", "y", "t"), 3L) + ) + val freqItemseqs1 = model1.freqItemsets.collect().map { itemset => + (itemset.items.toSeq, itemset.freq) + }.toSet + assert(freqItemseqs1 == expected) + } + + test("FP-Growth frequent itemsets using Int type") { val transactions = Seq( "1 2 3", "1 2 3 4", @@ -88,12 +132,14 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext { val model6 = fpg .setMinSupport(0.9) .setNumPartitions(1) + .setOrdered(false) .run(rdd) assert(model6.freqItemsets.count() === 0) val model3 = fpg .setMinSupport(0.5) .setNumPartitions(2) + .setOrdered(false) .run(rdd) assert(model3.freqItemsets.first().items.getClass === Array(1).getClass, "frequent itemsets should use primitive arrays") @@ -109,12 +155,14 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext { val model2 = fpg .setMinSupport(0.3) .setNumPartitions(4) + .setOrdered(false) .run(rdd) assert(model2.freqItemsets.count() === 15) val model1 = fpg .setMinSupport(0.1) .setNumPartitions(8) + .setOrdered(false) .run(rdd) assert(model1.freqItemsets.count() === 65) } diff --git a/python/pyspark/mllib/fpm.py b/python/pyspark/mllib/fpm.py index bdc4a132b1b18..b7f00d60069e6 100644 --- a/python/pyspark/mllib/fpm.py +++ b/python/pyspark/mllib/fpm.py @@ -39,8 +39,8 @@ class FPGrowthModel(JavaModelWrapper): >>> data = [["a", "b", "c"], ["a", "b", "d", "e"], ["a", "c", "e"], ["a", "c", "f"]] >>> rdd = sc.parallelize(data, 2) >>> model = FPGrowth.train(rdd, 0.6, 2) - >>> sorted(model.freqItemsets().collect()) - [FreqItemset(items=[u'a'], freq=4), FreqItemset(items=[u'c'], freq=3), ... + >>> sorted(model.freqItemsets().collect(), key=lambda x: x.items) + [FreqItemset(items=[u'a'], freq=4), FreqItemset(items=[u'a', u'c'], freq=3), ... 
""" def freqItemsets(self): From dfde31da5ce30e0d44cad4fb6618b44d5353d946 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Sun, 28 Jun 2015 22:38:04 -0700 Subject: [PATCH 037/122] [SPARK-5962] [MLLIB] Python support for Power Iteration Clustering Python support for Power Iteration Clustering https://issues.apache.org/jira/browse/SPARK-5962 Author: Yanbo Liang Closes #6992 from yanboliang/pyspark-pic and squashes the following commits: 6b03d82 [Yanbo Liang] address comments 4be4423 [Yanbo Liang] Python support for Power Iteration Clustering --- ...PowerIterationClusteringModelWrapper.scala | 32 ++++++ .../mllib/api/python/PythonMLLibAPI.scala | 27 +++++ python/pyspark/mllib/clustering.py | 98 ++++++++++++++++++- 3 files changed, 154 insertions(+), 3 deletions(-) create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/api/python/PowerIterationClusteringModelWrapper.scala diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PowerIterationClusteringModelWrapper.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PowerIterationClusteringModelWrapper.scala new file mode 100644 index 0000000000000..bc6041b221732 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PowerIterationClusteringModelWrapper.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.api.python + +import org.apache.spark.rdd.RDD +import org.apache.spark.mllib.clustering.PowerIterationClusteringModel + +/** + * A Wrapper of PowerIterationClusteringModel to provide helper method for Python + */ +private[python] class PowerIterationClusteringModelWrapper(model: PowerIterationClusteringModel) + extends PowerIterationClusteringModel(model.k, model.assignments) { + + def getAssignments: RDD[Array[Any]] = { + model.assignments.map(x => Array(x.id, x.cluster)) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index b16903a8d515c..a66a404d5c846 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -406,6 +406,33 @@ private[python] class PythonMLLibAPI extends Serializable { model.predictSoft(data).map(Vectors.dense) } + /** + * Java stub for Python mllib PowerIterationClustering.run(). This stub returns a + * handle to the Java object instead of the content of the Java object. Extra care + * needs to be taken in the Python code to ensure it gets freed on exit; see the + * Py4J documentation. + * @param data an RDD of (i, j, s,,ij,,) tuples representing the affinity matrix. + * @param k number of clusters. 
+ * @param maxIterations maximum number of iterations of the power iteration loop. + * @param initMode the initialization mode. This can be either "random" to use + * a random vector as vertex properties, or "degree" to use + * normalized sum similarities. Default: random. + */ + def trainPowerIterationClusteringModel( + data: JavaRDD[Vector], + k: Int, + maxIterations: Int, + initMode: String): PowerIterationClusteringModel = { + + val pic = new PowerIterationClustering() + .setK(k) + .setMaxIterations(maxIterations) + .setInitializationMode(initMode) + + val model = pic.run(data.rdd.map(v => (v(0).toLong, v(1).toLong, v(2)))) + new PowerIterationClusteringModelWrapper(model) + } + /** * Java stub for Python mllib ALS.train(). This stub returns a handle * to the Java object instead of the content of the Java object. Extra care diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py index 8bc0654c76ca3..e3c8a24c4a751 100644 --- a/python/pyspark/mllib/clustering.py +++ b/python/pyspark/mllib/clustering.py @@ -25,15 +25,18 @@ from numpy import array, random, tile +from collections import namedtuple + from pyspark import SparkContext from pyspark.rdd import RDD, ignore_unicode_prefix -from pyspark.mllib.common import callMLlibFunc, callJavaFunc, _py2java, _java2py +from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc, callJavaFunc, _py2java, _java2py from pyspark.mllib.linalg import SparseVector, _convert_to_vector, DenseVector from pyspark.mllib.stat.distribution import MultivariateGaussian -from pyspark.mllib.util import Saveable, Loader, inherit_doc +from pyspark.mllib.util import Saveable, Loader, inherit_doc, JavaLoader, JavaSaveable from pyspark.streaming import DStream __all__ = ['KMeansModel', 'KMeans', 'GaussianMixtureModel', 'GaussianMixture', + 'PowerIterationClusteringModel', 'PowerIterationClustering', 'StreamingKMeans', 'StreamingKMeansModel'] @@ -272,6 +275,94 @@ def train(cls, rdd, k, convergenceTol=1e-3, maxIterations=100, seed=None, initia return GaussianMixtureModel(weight, mvg_obj) +class PowerIterationClusteringModel(JavaModelWrapper, JavaSaveable, JavaLoader): + + """ + .. note:: Experimental + + Model produced by [[PowerIterationClustering]]. + + >>> data = [(0, 1, 1.0), (0, 2, 1.0), (1, 3, 1.0), (2, 3, 1.0), + ... (0, 3, 1.0), (1, 2, 1.0), (0, 4, 0.1)] + >>> rdd = sc.parallelize(data, 2) + >>> model = PowerIterationClustering.train(rdd, 2, 100) + >>> model.k + 2 + >>> sorted(model.assignments().collect()) + [Assignment(id=0, cluster=1), Assignment(id=1, cluster=0), ... + >>> import os, tempfile + >>> path = tempfile.mkdtemp() + >>> model.save(sc, path) + >>> sameModel = PowerIterationClusteringModel.load(sc, path) + >>> sameModel.k + 2 + >>> sorted(sameModel.assignments().collect()) + [Assignment(id=0, cluster=1), Assignment(id=1, cluster=0), ... + >>> from shutil import rmtree + >>> try: + ... rmtree(path) + ... except OSError: + ... pass + """ + + @property + def k(self): + """ + Returns the number of clusters. + """ + return self.call("k") + + def assignments(self): + """ + Returns the cluster assignments of this model. + """ + return self.call("getAssignments").map( + lambda x: (PowerIterationClustering.Assignment(*x))) + + @classmethod + def load(cls, sc, path): + model = cls._load_java(sc, path) + wrapper = sc._jvm.PowerIterationClusteringModelWrapper(model) + return PowerIterationClusteringModel(wrapper) + + +class PowerIterationClustering(object): + """ + .. 
note:: Experimental + + Power Iteration Clustering (PIC), a scalable graph clustering algorithm + developed by [[http://www.icml2010.org/papers/387.pdf Lin and Cohen]]. + From the abstract: PIC finds a very low-dimensional embedding of a + dataset using truncated power iteration on a normalized pair-wise + similarity matrix of the data. + """ + + @classmethod + def train(cls, rdd, k, maxIterations=100, initMode="random"): + """ + :param rdd: an RDD of (i, j, s,,ij,,) tuples representing the + affinity matrix, which is the matrix A in the PIC paper. + The similarity s,,ij,, must be nonnegative. + This is a symmetric matrix and hence s,,ij,, = s,,ji,,. + For any (i, j) with nonzero similarity, there should be + either (i, j, s,,ij,,) or (j, i, s,,ji,,) in the input. + Tuples with i = j are ignored, because we assume + s,,ij,, = 0.0. + :param k: Number of clusters. + :param maxIterations: Maximum number of iterations of the + PIC algorithm. + :param initMode: Initialization mode. + """ + model = callMLlibFunc("trainPowerIterationClusteringModel", + rdd.map(_convert_to_vector), int(k), int(maxIterations), initMode) + return PowerIterationClusteringModel(model) + + class Assignment(namedtuple("Assignment", ["id", "cluster"])): + """ + Represents an (id, cluster) tuple. + """ + + class StreamingKMeansModel(KMeansModel): """ .. note:: Experimental @@ -466,7 +557,8 @@ def predictOnValues(self, dstream): def _test(): import doctest - globs = globals().copy() + import pyspark.mllib.clustering + globs = pyspark.mllib.clustering.__dict__.copy() globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2) (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) globs['sc'].stop() From 0b10662fef11a56f82144b4953d457738e6961ae Mon Sep 17 00:00:00 2001 From: BenFradet Date: Sun, 28 Jun 2015 22:43:47 -0700 Subject: [PATCH 038/122] [SPARK-8575] [SQL] Deprecate callUDF in favor of udf Follow up of [SPARK-8356](https://issues.apache.org/jira/browse/SPARK-8356) and #6902. 
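For context, a short sketch of the two styles involved; the DataFrame is constructed inline and an existing `sqlContext` is assumed:

```scala
import org.apache.spark.sql.functions.{callUDF, col, udf}
import org.apache.spark.sql.types.DoubleType

// Assumes an existing SQLContext named sqlContext.
val df = sqlContext.createDataFrame(Seq((1, 2.0), (2, 3.5))).toDF("id", "value")

// Deprecated style: pass the function, its return type and the input column to callUDF.
val oldStyle = df.withColumn("doubled", callUDF((x: Double) => x * 2, DoubleType, col("value")))

// Preferred style: wrap the function once with udf, then apply it like any other column expression.
val doubleIt = udf((x: Double) => x * 2)
val newStyle = df.withColumn("doubled", doubleIt(col("value")))
```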
Removes the unit test for the now deprecated ```callUdf``` Unit test in SQLQuerySuite now uses ```udf``` instead of ```callUDF``` Replaced ```callUDF``` by ```udf``` where possible in mllib Author: BenFradet Closes #6993 from BenFradet/SPARK-8575 and squashes the following commits: 26f5a7a [BenFradet] 2 spaces instead of 1 1ddb452 [BenFradet] renamed initUDF in order to be consistent in OneVsRest 48ca15e [BenFradet] used vector type tag for udf call in VectorIndexer 0ebd0da [BenFradet] replace the now deprecated callUDF by udf in VectorIndexer 8013409 [BenFradet] replaced the now deprecated callUDF by udf in Predictor 94345b5 [BenFradet] unifomized udf calls in ProbabilisticClassifier 1305492 [BenFradet] uniformized udf calls in Classifier a672228 [BenFradet] uniformized udf calls in OneVsRest 49e4904 [BenFradet] Revert "removal of the unit test for the now deprecated callUdf" bbdeaf3 [BenFradet] fixed syntax for init udf in OneVsRest fe2a10b [BenFradet] callUDF => udf in ProbabilisticClassifier 0ea30b3 [BenFradet] callUDF => udf in Classifier where possible 197ec82 [BenFradet] callUDF => udf in OneVsRest 84d6780 [BenFradet] modified unit test in SQLQuerySuite to use udf instead of callUDF 477709f [BenFradet] removal of the unit test for the now deprecated callUdf --- .../scala/org/apache/spark/ml/Predictor.scala | 9 ++++--- .../spark/ml/classification/Classifier.scala | 13 ++++++--- .../spark/ml/classification/OneVsRest.scala | 27 +++++++++---------- .../ProbabilisticClassifier.scala | 22 ++++++++++----- .../spark/ml/feature/VectorIndexer.scala | 5 ++-- .../org/apache/spark/sql/SQLQuerySuite.scala | 5 ++-- 6 files changed, 46 insertions(+), 35 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala index edaa2afb790e6..333b42711ec52 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala @@ -122,9 +122,7 @@ abstract class Predictor[ */ protected def extractLabeledPoints(dataset: DataFrame): RDD[LabeledPoint] = { dataset.select($(labelCol), $(featuresCol)) - .map { case Row(label: Double, features: Vector) => - LabeledPoint(label, features) - } + .map { case Row(label: Double, features: Vector) => LabeledPoint(label, features) } } } @@ -171,7 +169,10 @@ abstract class PredictionModel[FeaturesType, M <: PredictionModel[FeaturesType, override def transform(dataset: DataFrame): DataFrame = { transformSchema(dataset.schema, logging = true) if ($(predictionCol).nonEmpty) { - dataset.withColumn($(predictionCol), callUDF(predict _, DoubleType, col($(featuresCol)))) + val predictUDF = udf { (features: Any) => + predict(features.asInstanceOf[FeaturesType]) + } + dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) } else { this.logWarning(s"$uid: Predictor.transform() was called as NOOP" + " since no output columns were set.") diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala index 14c285dbfc54a..85c097bc64a4f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala @@ -102,15 +102,20 @@ abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[Featur var outputData = dataset var numColsOutput = 0 if (getRawPredictionCol != "") { - outputData = outputData.withColumn(getRawPredictionCol, - 
callUDF(predictRaw _, new VectorUDT, col(getFeaturesCol))) + val predictRawUDF = udf { (features: Any) => + predictRaw(features.asInstanceOf[FeaturesType]) + } + outputData = outputData.withColumn(getRawPredictionCol, predictRawUDF(col(getFeaturesCol))) numColsOutput += 1 } if (getPredictionCol != "") { val predUDF = if (getRawPredictionCol != "") { - callUDF(raw2prediction _, DoubleType, col(getRawPredictionCol)) + udf(raw2prediction _).apply(col(getRawPredictionCol)) } else { - callUDF(predict _, DoubleType, col(getFeaturesCol)) + val predictUDF = udf { (features: Any) => + predict(features.asInstanceOf[FeaturesType]) + } + predictUDF(col(getFeaturesCol)) } outputData = outputData.withColumn(getPredictionCol, predUDF) numColsOutput += 1 diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala index b657882f8ad3f..ea757c5e40c76 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala @@ -88,9 +88,9 @@ final class OneVsRestModel private[ml] ( // add an accumulator column to store predictions of all the models val accColName = "mbc$acc" + UUID.randomUUID().toString - val init: () => Map[Int, Double] = () => {Map()} + val initUDF = udf { () => Map[Int, Double]() } val mapType = MapType(IntegerType, DoubleType, valueContainsNull = false) - val newDataset = dataset.withColumn(accColName, callUDF(init, mapType)) + val newDataset = dataset.withColumn(accColName, initUDF()) // persist if underlying dataset is not persistent. val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE @@ -106,13 +106,12 @@ final class OneVsRestModel private[ml] ( // add temporary column to store intermediate scores and update val tmpColName = "mbc$tmp" + UUID.randomUUID().toString - val update: (Map[Int, Double], Vector) => Map[Int, Double] = - (predictions: Map[Int, Double], prediction: Vector) => { - predictions + ((index, prediction(1))) - } - val updateUdf = callUDF(update, mapType, col(accColName), col(rawPredictionCol)) + val updateUDF = udf { (predictions: Map[Int, Double], prediction: Vector) => + predictions + ((index, prediction(1))) + } val transformedDataset = model.transform(df).select(columns : _*) - val updatedDataset = transformedDataset.withColumn(tmpColName, updateUdf) + val updatedDataset = transformedDataset + .withColumn(tmpColName, updateUDF(col(accColName), col(rawPredictionCol))) val newColumns = origCols ++ List(col(tmpColName)) // switch out the intermediate column with the accumulator column @@ -124,13 +123,13 @@ final class OneVsRestModel private[ml] ( } // output the index of the classifier with highest confidence as prediction - val label: Map[Int, Double] => Double = (predictions: Map[Int, Double]) => { + val labelUDF = udf { (predictions: Map[Int, Double]) => predictions.maxBy(_._2)._1.toDouble } // output label and label metadata as prediction - val labelUdf = callUDF(label, DoubleType, col(accColName)) - aggregatedDataset.withColumn($(predictionCol), labelUdf.as($(predictionCol), labelMetadata)) + aggregatedDataset + .withColumn($(predictionCol), labelUDF(col(accColName)).as($(predictionCol), labelMetadata)) .drop(accColName) } @@ -185,17 +184,15 @@ final class OneVsRest(override val uid: String) // create k columns, one for each binary classifier. 
val models = Range(0, numClasses).par.map { index => - - val label: Double => Double = (label: Double) => { + val labelUDF = udf { (label: Double) => if (label.toInt == index) 1.0 else 0.0 } // generate new label metadata for the binary problem. // TODO: use when ... otherwise after SPARK-7321 is merged - val labelUDF = callUDF(label, DoubleType, col($(labelCol))) val newLabelMeta = BinaryAttribute.defaultAttr.withName("label").toMetadata() val labelColName = "mc2b$" + index - val labelUDFWithNewMeta = labelUDF.as(labelColName, newLabelMeta) + val labelUDFWithNewMeta = labelUDF(col($(labelCol))).as(labelColName, newLabelMeta) val trainingDataset = multiclassLabeled.withColumn(labelColName, labelUDFWithNewMeta) val classifier = getClassifier classifier.fit(trainingDataset, classifier.labelCol -> labelColName) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala index 330ae2938f4e0..38e832372698c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala @@ -98,26 +98,34 @@ private[spark] abstract class ProbabilisticClassificationModel[ var outputData = dataset var numColsOutput = 0 if ($(rawPredictionCol).nonEmpty) { - outputData = outputData.withColumn(getRawPredictionCol, - callUDF(predictRaw _, new VectorUDT, col(getFeaturesCol))) + val predictRawUDF = udf { (features: Any) => + predictRaw(features.asInstanceOf[FeaturesType]) + } + outputData = outputData.withColumn(getRawPredictionCol, predictRawUDF(col(getFeaturesCol))) numColsOutput += 1 } if ($(probabilityCol).nonEmpty) { val probUDF = if ($(rawPredictionCol).nonEmpty) { - callUDF(raw2probability _, new VectorUDT, col($(rawPredictionCol))) + udf(raw2probability _).apply(col($(rawPredictionCol))) } else { - callUDF(predictProbability _, new VectorUDT, col($(featuresCol))) + val probabilityUDF = udf { (features: Any) => + predictProbability(features.asInstanceOf[FeaturesType]) + } + probabilityUDF(col($(featuresCol))) } outputData = outputData.withColumn($(probabilityCol), probUDF) numColsOutput += 1 } if ($(predictionCol).nonEmpty) { val predUDF = if ($(rawPredictionCol).nonEmpty) { - callUDF(raw2prediction _, DoubleType, col($(rawPredictionCol))) + udf(raw2prediction _).apply(col($(rawPredictionCol))) } else if ($(probabilityCol).nonEmpty) { - callUDF(probability2prediction _, DoubleType, col($(probabilityCol))) + udf(probability2prediction _).apply(col($(probabilityCol))) } else { - callUDF(predict _, DoubleType, col($(featuresCol))) + val predictUDF = udf { (features: Any) => + predict(features.asInstanceOf[FeaturesType]) + } + predictUDF(col($(featuresCol))) } outputData = outputData.withColumn($(predictionCol), predUDF) numColsOutput += 1 diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala index f4854a5e4b7b7..c73bdccdef5fa 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala @@ -30,7 +30,7 @@ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util.{Identifiable, SchemaUtils} import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, VectorUDT} import org.apache.spark.sql.{DataFrame, Row} -import org.apache.spark.sql.functions.callUDF 
+import org.apache.spark.sql.functions.udf import org.apache.spark.sql.types.{StructField, StructType} import org.apache.spark.util.collection.OpenHashSet @@ -339,7 +339,8 @@ class VectorIndexerModel private[ml] ( override def transform(dataset: DataFrame): DataFrame = { transformSchema(dataset.schema, logging = true) val newField = prepOutputField(dataset.schema) - val newCol = callUDF(transformFunc, new VectorUDT, dataset($(inputCol))) + val transformUDF = udf { (vector: Vector) => transformFunc(vector) } + val newCol = transformUDF(dataset($(inputCol))) dataset.withColumn($(outputCol), newCol.as($(outputCol), newField.metadata)) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 73bc6c999164e..22c54e43c1d16 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -137,13 +137,12 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { test("SPARK-7158 collect and take return different results") { import java.util.UUID - import org.apache.spark.sql.types._ val df = Seq(Tuple1(1), Tuple1(2), Tuple1(3)).toDF("index") // we except the id is materialized once - def id: () => String = () => { UUID.randomUUID().toString() } + val idUdf = udf(() => UUID.randomUUID().toString) - val dfWithId = df.withColumn("id", callUDF(id, StringType)) + val dfWithId = df.withColumn("id", idUdf()) // Make a new DataFrame (actually the same reference to the old one) val cached = dfWithId.cache() // Trigger the cache From ac2e17b01c0843d928a363d2cc4faf57ec8c8b47 Mon Sep 17 00:00:00 2001 From: Cheolsoo Park Date: Mon, 29 Jun 2015 00:13:39 -0700 Subject: [PATCH 039/122] [SPARK-8355] [SQL] Python DataFrameReader/Writer should mirror Scala I compared PySpark DataFrameReader/Writer against Scala ones. `Option` function is missing in both reader and writer, but the rest seems to all match. I added `Option` to reader and writer and updated the `pyspark-sql` test. Author: Cheolsoo Park Closes #7078 from piaozhexiu/SPARK-8355 and squashes the following commits: c63d419 [Cheolsoo Park] Fix version 524e0aa [Cheolsoo Park] Add option function to df reader and writer --- python/pyspark/sql/readwriter.py | 14 ++++++++++++++ python/pyspark/sql/tests.py | 1 + 2 files changed, 15 insertions(+) diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 1b7bc0f9a12be..c4cc62e82a160 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -73,6 +73,13 @@ def schema(self, schema): self._jreader = self._jreader.schema(jschema) return self + @since(1.5) + def option(self, key, value): + """Adds an input option for the underlying data source. + """ + self._jreader = self._jreader.option(key, value) + return self + @since(1.4) def options(self, **options): """Adds input options for the underlying data source. @@ -235,6 +242,13 @@ def format(self, source): self._jwrite = self._jwrite.format(source) return self + @since(1.5) + def option(self, key, value): + """Adds an output option for the underlying data source. + """ + self._jwrite = self._jwrite.option(key, value) + return self + @since(1.4) def options(self, **options): """Adds output options for the underlying data source. 
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index e6a434e4b2dff..ffee43a94baba 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -564,6 +564,7 @@ def test_save_and_load_builder(self): self.assertEqual(sorted(df.collect()), sorted(actual.collect())) df.write.mode("overwrite").options(noUse="this options will not be used in save.")\ + .option("noUse", "this option will not be used in save.")\ .format("json").save(path=tmpPath) actual =\ self.sqlCtx.read.format("json")\ From 660c6cec75dc165cf5d62cdc1b0951bdb93df365 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 29 Jun 2015 00:22:44 -0700 Subject: [PATCH 040/122] [SPARK-8698] partitionBy in Python DataFrame reader/writer interface should not default to empty tuple. Author: Reynold Xin Closes #7079 from rxin/SPARK-8698 and squashes the following commits: 8513e1c [Reynold Xin] [SPARK-8698] partitionBy in Python DataFrame reader/writer interface should not default to empty tuple. --- python/pyspark/sql/readwriter.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index c4cc62e82a160..882a03090ec13 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -270,12 +270,11 @@ def partitionBy(self, *cols): """ if len(cols) == 1 and isinstance(cols[0], (list, tuple)): cols = cols[0] - if len(cols) > 0: - self._jwrite = self._jwrite.partitionBy(_to_seq(self._sqlContext._sc, cols)) + self._jwrite = self._jwrite.partitionBy(_to_seq(self._sqlContext._sc, cols)) return self @since(1.4) - def save(self, path=None, format=None, mode=None, partitionBy=(), **options): + def save(self, path=None, format=None, mode=None, partitionBy=None, **options): """Saves the contents of the :class:`DataFrame` to a data source. The data source is specified by the ``format`` and a set of ``options``. @@ -295,7 +294,9 @@ def save(self, path=None, format=None, mode=None, partitionBy=(), **options): >>> df.write.mode('append').parquet(os.path.join(tempfile.mkdtemp(), 'data')) """ - self.partitionBy(partitionBy).mode(mode).options(**options) + self.mode(mode).options(**options) + if partitionBy is not None: + self.partitionBy(partitionBy) if format is not None: self.format(format) if path is None: @@ -315,7 +316,7 @@ def insertInto(self, tableName, overwrite=False): self._jwrite.mode("overwrite" if overwrite else "append").insertInto(tableName) @since(1.4) - def saveAsTable(self, name, format=None, mode=None, partitionBy=(), **options): + def saveAsTable(self, name, format=None, mode=None, partitionBy=None, **options): """Saves the content of the :class:`DataFrame` as the specified table. In the case the table already exists, behavior of this function depends on the @@ -334,7 +335,9 @@ def saveAsTable(self, name, format=None, mode=None, partitionBy=(), **options): :param partitionBy: names of partitioning columns :param options: all other string options """ - self.partitionBy(partitionBy).mode(mode).options(**options) + self.mode(mode).options(**options) + if partitionBy is not None: + self.partitionBy(partitionBy) if format is not None: self.format(format) self._jwrite.saveAsTable(name) @@ -356,7 +359,7 @@ def json(self, path, mode=None): self.mode(mode)._jwrite.json(path) @since(1.4) - def parquet(self, path, mode=None, partitionBy=()): + def parquet(self, path, mode=None, partitionBy=None): """Saves the content of the :class:`DataFrame` in Parquet format at the specified path. 
:param path: the path in any Hadoop supported file system @@ -370,7 +373,9 @@ def parquet(self, path, mode=None, partitionBy=()): >>> df.write.parquet(os.path.join(tempfile.mkdtemp(), 'data')) """ - self.partitionBy(partitionBy).mode(mode) + self.mode(mode) + if partitionBy is not None: + self.partitionBy(partitionBy) self._jwrite.parquet(path) @since(1.4) From 630bd5fd80193ab6dc6ad0e7bcc13ee0dadabd38 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Tue, 30 Jun 2015 00:46:55 +0900 Subject: [PATCH 041/122] [SPARK-8702] [WEBUI] Avoid massive concating strings in Javascript When there are massive tasks, such as `sc.parallelize(1 to 100000, 10000).count()`, the generated JS codes have a lot of string concatenations in the stage page, nearly 40 string concatenations for one task. We can generate the whole string for a task instead of execution string concatenations in the browser. Before this patch, the load time of the page is about 21 seconds. ![screen shot 2015-06-29 at 6 44 04 pm](https://cloud.githubusercontent.com/assets/1000778/8406644/eb55ed18-1e90-11e5-9ad5-50d27ad1dff1.png) After this patch, it reduces to about 17 seconds. ![screen shot 2015-06-29 at 6 47 34 pm](https://cloud.githubusercontent.com/assets/1000778/8406665/087003ca-1e91-11e5-80a8-3485aa9adafa.png) One disadvantage is that the generated JS codes become hard to read. Author: zsxwing Closes #7082 from zsxwing/js-string and squashes the following commits: b29231d [zsxwing] Avoid massive concating strings in Javascript --- .../org/apache/spark/ui/jobs/StagePage.scala | 88 +++++++++---------- 1 file changed, 44 insertions(+), 44 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index b83a49f79c8a8..e96bf49d0dd14 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -572,55 +572,55 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { val attempt = taskInfo.attempt val timelineObject = s""" - { - 'className': 'task task-assignment-timeline-object', - 'group': '$executorId', - 'content': '
' + - 'Status: ${taskInfo.status}<br>
' + - 'Launch Time: ${UIUtils.formatDate(new Date(launchTime))}' + - '${ + |{ + |'className': 'task task-assignment-timeline-object', + |'group': '$executorId', + |'content': '
+ |Status: ${taskInfo.status}<br>
+ |Launch Time: ${UIUtils.formatDate(new Date(launchTime))} + |${ if (!taskInfo.running) { s"""<br>
Finish Time: ${UIUtils.formatDate(new Date(finishTime))}""" } else { "" } - }' + - '<br>
Scheduler Delay: $schedulerDelay ms' + - '<br>
Task Deserialization Time: ${UIUtils.formatDuration(deserializationTime)}' + - '<br>
Shuffle Read Time: ${UIUtils.formatDuration(shuffleReadTime)}' + - '<br>
Executor Computing Time: ${UIUtils.formatDuration(executorComputingTime)}' + - '<br>
Shuffle Write Time: ${UIUtils.formatDuration(shuffleWriteTime)}' + - '<br>
Result Serialization Time: ${UIUtils.formatDuration(serializationTime)}' + - '<br>
Getting Result Time: ${UIUtils.formatDuration(gettingResultTime)}">' + - '' + - '' + - '' + - '' + - '' + - '' + - '', - 'start': new Date($launchTime), - 'end': new Date($finishTime) - } - """ + } + |<br>
Scheduler Delay: $schedulerDelay ms + |<br>
Task Deserialization Time: ${UIUtils.formatDuration(deserializationTime)} + |<br>
Shuffle Read Time: ${UIUtils.formatDuration(shuffleReadTime)} + |<br>
Executor Computing Time: ${UIUtils.formatDuration(executorComputingTime)} + |<br>
Shuffle Write Time: ${UIUtils.formatDuration(shuffleWriteTime)} + |<br>
Result Serialization Time: ${UIUtils.formatDuration(serializationTime)} + |<br>
Getting Result Time: ${UIUtils.formatDuration(gettingResultTime)}"> + | + | + | + | + | + | + | + |', + |'start': new Date($launchTime), + |'end': new Date($finishTime) + |} + |""".stripMargin.replaceAll("\n", " ") timelineObject }.mkString("[", ",", "]") From 5c796d576ec2de96bf72dbf6ccd0e85480a6e3b1 Mon Sep 17 00:00:00 2001 From: Brennon York Date: Mon, 29 Jun 2015 08:55:06 -0700 Subject: [PATCH 042/122] [SPARK-8693] [PROJECT INFRA] profiles and goals are not printed in a nice way Hotfix to correct formatting errors of print statements within the dev and jenkins builds. Error looks like: ``` -Phadoop-1[info] Building Spark (w/Hive 0.13.1) using SBT with these arguments: -Dhadoop.version=1.0.4[info] Building Spark (w/Hive 0.13.1) using SBT with these arguments: -Pkinesis-asl[info] Building Spark (w/Hive 0.13.1) using SBT with these arguments: -Phive-thriftserver[info] Building Spark (w/Hive 0.13.1) using SBT with these arguments: -Phive[info] Building Spark (w/Hive 0.13.1) using SBT with these arguments: package[info] Building Spark (w/Hive 0.13.1) using SBT with these arguments: assembly/assembly[info] Building Spark (w/Hive 0.13.1) using SBT with these arguments: streaming-kafka-assembly/assembly ``` Author: Brennon York Closes #7085 from brennonyork/SPARK-8693 and squashes the following commits: c5575f1 [Brennon York] added commas to end of print statements for proper printing --- dev/run-tests.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/dev/run-tests.py b/dev/run-tests.py index eb79a2a502707..e5c897b94d167 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -210,7 +210,7 @@ def build_spark_documentation(): jekyll_bin = which("jekyll") if not jekyll_bin: - print("[error] Cannot find a version of `jekyll` on the system; please" + print("[error] Cannot find a version of `jekyll` on the system; please", " install one and retry to build documentation.") sys.exit(int(os.environ.get("CURRENT_BLOCK", 255))) else: @@ -270,7 +270,7 @@ def get_hadoop_profiles(hadoop_version): if hadoop_version in sbt_maven_hadoop_profiles: return sbt_maven_hadoop_profiles[hadoop_version] else: - print("[error] Could not find", hadoop_version, "in the list. Valid options" + print("[error] Could not find", hadoop_version, "in the list. 
Valid options", " are", sbt_maven_hadoop_profiles.keys()) sys.exit(int(os.environ.get("CURRENT_BLOCK", 255))) @@ -281,7 +281,7 @@ def build_spark_maven(hadoop_version): mvn_goals = ["clean", "package", "-DskipTests"] profiles_and_goals = build_profiles + mvn_goals - print("[info] Building Spark (w/Hive 0.13.1) using Maven with these arguments: " + print("[info] Building Spark (w/Hive 0.13.1) using Maven with these arguments: ", " ".join(profiles_and_goals)) exec_maven(profiles_and_goals) @@ -295,7 +295,7 @@ def build_spark_sbt(hadoop_version): "streaming-kafka-assembly/assembly"] profiles_and_goals = build_profiles + sbt_goals - print("[info] Building Spark (w/Hive 0.13.1) using SBT with these arguments: " + print("[info] Building Spark (w/Hive 0.13.1) using SBT with these arguments: ", " ".join(profiles_and_goals)) exec_sbt(profiles_and_goals) @@ -324,7 +324,7 @@ def run_scala_tests_maven(test_profiles): mvn_test_goals = ["test", "--fail-at-end"] profiles_and_goals = test_profiles + mvn_test_goals - print("[info] Running Spark tests using Maven with these arguments: " + print("[info] Running Spark tests using Maven with these arguments: ", " ".join(profiles_and_goals)) exec_maven(profiles_and_goals) @@ -339,7 +339,7 @@ def run_scala_tests_sbt(test_modules, test_profiles): profiles_and_goals = test_profiles + list(sbt_test_goals) - print("[info] Running Spark tests using SBT with these arguments: " + print("[info] Running Spark tests using SBT with these arguments: ", " ".join(profiles_and_goals)) exec_sbt(profiles_and_goals) @@ -382,7 +382,7 @@ def run_sparkr_tests(): def main(): # Ensure the user home directory (HOME) is valid and is an absolute directory if not USER_HOME or not os.path.isabs(USER_HOME): - print("[error] Cannot determine your home directory as an absolute path;" + print("[error] Cannot determine your home directory as an absolute path;", " ensure the $HOME environment variable is set properly.") sys.exit(1) @@ -397,7 +397,7 @@ def main(): java_exe = determine_java_executable() if not java_exe: - print("[error] Cannot find a version of `java` on the system; please" + print("[error] Cannot find a version of `java` on the system; please", " install one and retry.") sys.exit(2) From 715f084ca08ad48174ab19a699a0ac77f80b68cd Mon Sep 17 00:00:00 2001 From: Yu ISHIKAWA Date: Mon, 29 Jun 2015 09:22:55 -0700 Subject: [PATCH 043/122] [SPARK-8554] Add the SparkR document files to `.rat-excludes` for `./dev/check-license` [[SPARK-8554] Add the SparkR document files to `.rat-excludes` for `./dev/check-license` - ASF JIRA](https://issues.apache.org/jira/browse/SPARK-8554) Author: Yu ISHIKAWA Closes #6947 from yu-iskw/SPARK-8554 and squashes the following commits: 5ca240c [Yu ISHIKAWA] [SPARK-8554] Add the SparkR document files to `.rat-excludes` for `./dev/check-license` --- .rat-excludes | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.rat-excludes b/.rat-excludes index c24667c18dbda..0240e81c45ea2 100644 --- a/.rat-excludes +++ b/.rat-excludes @@ -86,4 +86,8 @@ local-1430917381535_2 DESCRIPTION NAMESPACE test_support/* +.*Rd +help/* +html/* +INDEX .lintr From ea88b1a5077e6ba980b0de6d3bc508c62285ba4c Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Mon, 29 Jun 2015 10:52:05 -0700 Subject: [PATCH 044/122] Revert "[SPARK-8372] History server shows incorrect information for application not started" This reverts commit 2837e067099921dd4ab6639ac5f6e89f789d4ff4. 
--- .../deploy/history/FsHistoryProvider.scala | 38 +++++++--------- .../history/FsHistoryProviderSuite.scala | 43 ++++++------------- 2 files changed, 28 insertions(+), 53 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index db383b9823d3c..5427a88f32ffd 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -160,7 +160,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) replayBus.addListener(appListener) val appInfo = replay(fs.getFileStatus(new Path(logDir, attempt.logPath)), replayBus) - appInfo.foreach { app => ui.setAppName(s"${app.name} ($appId)") } + ui.setAppName(s"${appInfo.name} ($appId)") val uiAclsEnabled = conf.getBoolean("spark.history.ui.acls.enable", false) ui.getSecurityManager.setAcls(uiAclsEnabled) @@ -282,12 +282,8 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) val newAttempts = logs.flatMap { fileStatus => try { val res = replay(fileStatus, bus) - res match { - case Some(r) => logDebug(s"Application log ${r.logPath} loaded successfully.") - case None => logWarning(s"Failed to load application log ${fileStatus.getPath}. " + - "The application may have not started.") - } - res + logInfo(s"Application log ${res.logPath} loaded successfully.") + Some(res) } catch { case e: Exception => logError( @@ -433,11 +429,9 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) /** * Replays the events in the specified log file and returns information about the associated - * application. Return `None` if the application ID cannot be located. + * application. 
*/ - private def replay( - eventLog: FileStatus, - bus: ReplayListenerBus): Option[FsApplicationAttemptInfo] = { + private def replay(eventLog: FileStatus, bus: ReplayListenerBus): FsApplicationAttemptInfo = { val logPath = eventLog.getPath() logInfo(s"Replaying log path: $logPath") val logInput = @@ -451,18 +445,16 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) val appCompleted = isApplicationCompleted(eventLog) bus.addListener(appListener) bus.replay(logInput, logPath.toString, !appCompleted) - appListener.appId.map { appId => - new FsApplicationAttemptInfo( - logPath.getName(), - appListener.appName.getOrElse(NOT_STARTED), - appId, - appListener.appAttemptId, - appListener.startTime.getOrElse(-1L), - appListener.endTime.getOrElse(-1L), - getModificationTime(eventLog).get, - appListener.sparkUser.getOrElse(NOT_STARTED), - appCompleted) - } + new FsApplicationAttemptInfo( + logPath.getName(), + appListener.appName.getOrElse(NOT_STARTED), + appListener.appId.getOrElse(logPath.getName()), + appListener.appAttemptId, + appListener.startTime.getOrElse(-1L), + appListener.endTime.getOrElse(-1L), + getModificationTime(eventLog).get, + appListener.sparkUser.getOrElse(NOT_STARTED), + appCompleted) } finally { logInput.close() } diff --git a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala index d3a6db5f260d6..09075eeb539aa 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala @@ -67,8 +67,7 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc // Write a new-style application log. val newAppComplete = newLogFile("new1", None, inProgress = false) writeFile(newAppComplete, true, None, - SparkListenerApplicationStart( - "new-app-complete", Some("new-app-complete"), 1L, "test", None), + SparkListenerApplicationStart("new-app-complete", None, 1L, "test", None), SparkListenerApplicationEnd(5L) ) @@ -76,15 +75,13 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc val newAppCompressedComplete = newLogFile("new1compressed", None, inProgress = false, Some("lzf")) writeFile(newAppCompressedComplete, true, None, - SparkListenerApplicationStart( - "new-app-compressed-complete", Some("new-app-compressed-complete"), 1L, "test", None), + SparkListenerApplicationStart("new-app-compressed-complete", None, 1L, "test", None), SparkListenerApplicationEnd(4L)) // Write an unfinished app, new-style. val newAppIncomplete = newLogFile("new2", None, inProgress = true) writeFile(newAppIncomplete, true, None, - SparkListenerApplicationStart( - "new-app-incomplete", Some("new-app-incomplete"), 1L, "test", None) + SparkListenerApplicationStart("new-app-incomplete", None, 1L, "test", None) ) // Write an old-style application log. 
@@ -92,8 +89,7 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc oldAppComplete.mkdir() createEmptyFile(new File(oldAppComplete, provider.SPARK_VERSION_PREFIX + "1.0")) writeFile(new File(oldAppComplete, provider.LOG_PREFIX + "1"), false, None, - SparkListenerApplicationStart( - "old-app-complete", Some("old-app-complete"), 2L, "test", None), + SparkListenerApplicationStart("old-app-complete", None, 2L, "test", None), SparkListenerApplicationEnd(3L) ) createEmptyFile(new File(oldAppComplete, provider.APPLICATION_COMPLETE)) @@ -107,8 +103,7 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc oldAppIncomplete.mkdir() createEmptyFile(new File(oldAppIncomplete, provider.SPARK_VERSION_PREFIX + "1.0")) writeFile(new File(oldAppIncomplete, provider.LOG_PREFIX + "1"), false, None, - SparkListenerApplicationStart( - "old-app-incomplete", Some("old-app-incomplete"), 2L, "test", None) + SparkListenerApplicationStart("old-app-incomplete", None, 2L, "test", None) ) // Force a reload of data from the log directory, and check that both logs are loaded. @@ -129,16 +124,16 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc List(ApplicationAttemptInfo(None, start, end, lastMod, user, completed))) } - list(0) should be (makeAppInfo("new-app-complete", "new-app-complete", 1L, 5L, + list(0) should be (makeAppInfo(newAppComplete.getName(), "new-app-complete", 1L, 5L, newAppComplete.lastModified(), "test", true)) - list(1) should be (makeAppInfo("new-app-compressed-complete", + list(1) should be (makeAppInfo(newAppCompressedComplete.getName(), "new-app-compressed-complete", 1L, 4L, newAppCompressedComplete.lastModified(), "test", true)) - list(2) should be (makeAppInfo("old-app-complete", "old-app-complete", 2L, 3L, + list(2) should be (makeAppInfo(oldAppComplete.getName(), "old-app-complete", 2L, 3L, oldAppComplete.lastModified(), "test", true)) - list(3) should be (makeAppInfo("old-app-incomplete", "old-app-incomplete", 2L, -1L, + list(3) should be (makeAppInfo(oldAppIncomplete.getName(), "old-app-incomplete", 2L, -1L, oldAppIncomplete.lastModified(), "test", false)) - list(4) should be (makeAppInfo("new-app-incomplete", "new-app-incomplete", 1L, -1L, + list(4) should be (makeAppInfo(newAppIncomplete.getName(), "new-app-incomplete", 1L, -1L, newAppIncomplete.lastModified(), "test", false)) // Make sure the UI can be rendered. 
@@ -162,7 +157,7 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc logDir.mkdir() createEmptyFile(new File(logDir, provider.SPARK_VERSION_PREFIX + "1.0")) writeFile(new File(logDir, provider.LOG_PREFIX + "1"), false, Option(codec), - SparkListenerApplicationStart("app2", Some("app2"), 2L, "test", None), + SparkListenerApplicationStart("app2", None, 2L, "test", None), SparkListenerApplicationEnd(3L) ) createEmptyFile(new File(logDir, provider.COMPRESSION_CODEC_PREFIX + codecName)) @@ -185,12 +180,12 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc test("SPARK-3697: ignore directories that cannot be read.") { val logFile1 = newLogFile("new1", None, inProgress = false) writeFile(logFile1, true, None, - SparkListenerApplicationStart("app1-1", Some("app1-1"), 1L, "test", None), + SparkListenerApplicationStart("app1-1", None, 1L, "test", None), SparkListenerApplicationEnd(2L) ) val logFile2 = newLogFile("new2", None, inProgress = false) writeFile(logFile2, true, None, - SparkListenerApplicationStart("app1-2", Some("app1-2"), 1L, "test", None), + SparkListenerApplicationStart("app1-2", None, 1L, "test", None), SparkListenerApplicationEnd(2L) ) logFile2.setReadable(false, false) @@ -223,18 +218,6 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc } } - test("Parse logs that application is not started") { - val provider = new FsHistoryProvider((createTestConf())) - - val logFile1 = newLogFile("app1", None, inProgress = true) - writeFile(logFile1, true, None, - SparkListenerLogStart("1.4") - ) - updateAndCheck(provider) { list => - list.size should be (0) - } - } - test("SPARK-5582: empty log directory") { val provider = new FsHistoryProvider(createTestConf()) From ed413bcc78d8d97a1a0cd0871d7a20f7170476d0 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 29 Jun 2015 11:41:26 -0700 Subject: [PATCH 045/122] [SPARK-8692] [SQL] re-order the case statements that handle catalyst data types so they use the same order: boolean, byte, short, int, date, long, timestamp, float, double, string, binary, decimal. Then we can easily check at a glance whether any data types are missing, and make sure date/timestamp are handled just like int/long.
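(Editor's illustration, not part of the original patch: a minimal sketch of the ordering convention described above, using a hypothetical `defaultSizeOf` helper. The point is only that every `match` over Catalyst data types lists its cases in the same canonical order, with `DateType` folded into the `IntegerType` case and `TimestampType` into the `LongType` case, so a missing type is easy to spot. The sizes follow the `defaultSize` expectations in the `ColumnTypeSuite` changes later in this patch.)

```scala
import org.apache.spark.sql.types._

// Hypothetical helper (not from the patch), shown only to illustrate the canonical case order:
// boolean, byte, short, int/date, long/timestamp, float, double, string, binary, decimal.
object TypeOrderExample {
  def defaultSizeOf(dataType: DataType): Int = dataType match {
    case BooleanType => 1
    case ByteType => 1
    case ShortType => 2
    case IntegerType | DateType => 4     // DATE is stored as an INT internally
    case LongType | TimestampType => 8   // TIMESTAMP is stored as a LONG internally
    case FloatType => 4
    case DoubleType => 8
    case StringType => 8
    case BinaryType => 16
    case _: DecimalType => 8
    case _ => 16                         // generic/other types
  }
}
```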
Author: Wenchen Fan Closes #7073 from cloud-fan/fix-date and squashes the following commits: 463044d [Wenchen Fan] fix style 51cd347 [Wenchen Fan] refactor handling of date and timestmap --- .../expressions/SpecificMutableRow.scala | 12 +-- .../expressions/UnsafeRowConverter.scala | 6 +- .../expressions/codegen/CodeGenerator.scala | 6 +- .../spark/sql/columnar/ColumnAccessor.scala | 42 +++++----- .../spark/sql/columnar/ColumnBuilder.scala | 30 +++---- .../spark/sql/columnar/ColumnStats.scala | 74 ++++++++--------- .../spark/sql/columnar/ColumnType.scala | 10 +-- .../sql/execution/SparkSqlSerializer2.scala | 82 ++++++------------- .../sql/parquet/ParquetTableSupport.scala | 34 ++++---- .../spark/sql/parquet/ParquetTypes.scala | 4 +- .../spark/sql/columnar/ColumnStatsSuite.scala | 9 +- .../spark/sql/columnar/ColumnTypeSuite.scala | 54 ++++++------ .../sql/columnar/ColumnarTestUtils.scala | 8 +- .../NullableColumnAccessorSuite.scala | 6 +- .../columnar/NullableColumnBuilderSuite.scala | 6 +- 15 files changed, 174 insertions(+), 209 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala index 53fedb531cfb2..3928c0f2ffdaf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala @@ -196,15 +196,15 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR def this(dataTypes: Seq[DataType]) = this( dataTypes.map { - case IntegerType => new MutableInt + case BooleanType => new MutableBoolean case ByteType => new MutableByte - case FloatType => new MutableFloat case ShortType => new MutableShort + // We use INT for DATE internally + case IntegerType | DateType => new MutableInt + // We use Long for Timestamp internally + case LongType | TimestampType => new MutableLong + case FloatType => new MutableFloat case DoubleType => new MutableDouble - case BooleanType => new MutableBoolean - case LongType => new MutableLong - case DateType => new MutableInt // We use INT for DATE internally - case TimestampType => new MutableLong // We use Long for Timestamp internally case _ => new MutableAny }.toArray) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala index 89adaf053b1a4..b61d490429e4f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala @@ -128,14 +128,12 @@ private object UnsafeColumnWriter { case BooleanType => BooleanUnsafeColumnWriter case ByteType => ByteUnsafeColumnWriter case ShortType => ShortUnsafeColumnWriter - case IntegerType => IntUnsafeColumnWriter - case LongType => LongUnsafeColumnWriter + case IntegerType | DateType => IntUnsafeColumnWriter + case LongType | TimestampType => LongUnsafeColumnWriter case FloatType => FloatUnsafeColumnWriter case DoubleType => DoubleUnsafeColumnWriter case StringType => StringUnsafeColumnWriter case BinaryType => BinaryUnsafeColumnWriter - case DateType => IntUnsafeColumnWriter - case TimestampType => LongUnsafeColumnWriter case t => throw new UnsupportedOperationException(s"Do not know how to write columns of type 
$t") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index e20e3a9dca502..57e0bede5db20 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -120,15 +120,13 @@ class CodeGenContext { case BooleanType => JAVA_BOOLEAN case ByteType => JAVA_BYTE case ShortType => JAVA_SHORT - case IntegerType => JAVA_INT - case LongType => JAVA_LONG + case IntegerType | DateType => JAVA_INT + case LongType | TimestampType => JAVA_LONG case FloatType => JAVA_FLOAT case DoubleType => JAVA_DOUBLE case dt: DecimalType => decimalType case BinaryType => "byte[]" case StringType => stringType - case DateType => JAVA_INT - case TimestampType => JAVA_LONG case dt: OpenHashSetUDT if dt.elementType == IntegerType => classOf[IntegerHashSet].getName case dt: OpenHashSetUDT if dt.elementType == LongType => classOf[LongHashSet].getName case _ => "Object" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala index 64449b2659b4b..931469bed634a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnAccessor.scala @@ -71,44 +71,44 @@ private[sql] abstract class NativeColumnAccessor[T <: AtomicType]( private[sql] class BooleanColumnAccessor(buffer: ByteBuffer) extends NativeColumnAccessor(buffer, BOOLEAN) -private[sql] class IntColumnAccessor(buffer: ByteBuffer) - extends NativeColumnAccessor(buffer, INT) +private[sql] class ByteColumnAccessor(buffer: ByteBuffer) + extends NativeColumnAccessor(buffer, BYTE) private[sql] class ShortColumnAccessor(buffer: ByteBuffer) extends NativeColumnAccessor(buffer, SHORT) +private[sql] class IntColumnAccessor(buffer: ByteBuffer) + extends NativeColumnAccessor(buffer, INT) + private[sql] class LongColumnAccessor(buffer: ByteBuffer) extends NativeColumnAccessor(buffer, LONG) -private[sql] class ByteColumnAccessor(buffer: ByteBuffer) - extends NativeColumnAccessor(buffer, BYTE) - -private[sql] class DoubleColumnAccessor(buffer: ByteBuffer) - extends NativeColumnAccessor(buffer, DOUBLE) - private[sql] class FloatColumnAccessor(buffer: ByteBuffer) extends NativeColumnAccessor(buffer, FLOAT) -private[sql] class FixedDecimalColumnAccessor(buffer: ByteBuffer, precision: Int, scale: Int) - extends NativeColumnAccessor(buffer, FIXED_DECIMAL(precision, scale)) +private[sql] class DoubleColumnAccessor(buffer: ByteBuffer) + extends NativeColumnAccessor(buffer, DOUBLE) private[sql] class StringColumnAccessor(buffer: ByteBuffer) extends NativeColumnAccessor(buffer, STRING) -private[sql] class DateColumnAccessor(buffer: ByteBuffer) - extends NativeColumnAccessor(buffer, DATE) - -private[sql] class TimestampColumnAccessor(buffer: ByteBuffer) - extends NativeColumnAccessor(buffer, TIMESTAMP) - private[sql] class BinaryColumnAccessor(buffer: ByteBuffer) extends BasicColumnAccessor[BinaryType.type, Array[Byte]](buffer, BINARY) with NullableColumnAccessor +private[sql] class FixedDecimalColumnAccessor(buffer: ByteBuffer, precision: Int, scale: Int) + extends NativeColumnAccessor(buffer, FIXED_DECIMAL(precision, scale)) + private[sql] class GenericColumnAccessor(buffer: ByteBuffer) 
extends BasicColumnAccessor[DataType, Array[Byte]](buffer, GENERIC) with NullableColumnAccessor +private[sql] class DateColumnAccessor(buffer: ByteBuffer) + extends NativeColumnAccessor(buffer, DATE) + +private[sql] class TimestampColumnAccessor(buffer: ByteBuffer) + extends NativeColumnAccessor(buffer, TIMESTAMP) + private[sql] object ColumnAccessor { def apply(dataType: DataType, buffer: ByteBuffer): ColumnAccessor = { val dup = buffer.duplicate().order(ByteOrder.nativeOrder) @@ -118,17 +118,17 @@ private[sql] object ColumnAccessor { dup.getInt() dataType match { + case BooleanType => new BooleanColumnAccessor(dup) + case ByteType => new ByteColumnAccessor(dup) + case ShortType => new ShortColumnAccessor(dup) case IntegerType => new IntColumnAccessor(dup) + case DateType => new DateColumnAccessor(dup) case LongType => new LongColumnAccessor(dup) + case TimestampType => new TimestampColumnAccessor(dup) case FloatType => new FloatColumnAccessor(dup) case DoubleType => new DoubleColumnAccessor(dup) - case BooleanType => new BooleanColumnAccessor(dup) - case ByteType => new ByteColumnAccessor(dup) - case ShortType => new ShortColumnAccessor(dup) case StringType => new StringColumnAccessor(dup) case BinaryType => new BinaryColumnAccessor(dup) - case DateType => new DateColumnAccessor(dup) - case TimestampType => new TimestampColumnAccessor(dup) case DecimalType.Fixed(precision, scale) if precision < 19 => new FixedDecimalColumnAccessor(dup, precision, scale) case _ => new GenericColumnAccessor(dup) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala index 1949625699ca8..087c52239713d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnBuilder.scala @@ -94,17 +94,21 @@ private[sql] abstract class NativeColumnBuilder[T <: AtomicType]( private[sql] class BooleanColumnBuilder extends NativeColumnBuilder(new BooleanColumnStats, BOOLEAN) -private[sql] class IntColumnBuilder extends NativeColumnBuilder(new IntColumnStats, INT) +private[sql] class ByteColumnBuilder extends NativeColumnBuilder(new ByteColumnStats, BYTE) private[sql] class ShortColumnBuilder extends NativeColumnBuilder(new ShortColumnStats, SHORT) +private[sql] class IntColumnBuilder extends NativeColumnBuilder(new IntColumnStats, INT) + private[sql] class LongColumnBuilder extends NativeColumnBuilder(new LongColumnStats, LONG) -private[sql] class ByteColumnBuilder extends NativeColumnBuilder(new ByteColumnStats, BYTE) +private[sql] class FloatColumnBuilder extends NativeColumnBuilder(new FloatColumnStats, FLOAT) private[sql] class DoubleColumnBuilder extends NativeColumnBuilder(new DoubleColumnStats, DOUBLE) -private[sql] class FloatColumnBuilder extends NativeColumnBuilder(new FloatColumnStats, FLOAT) +private[sql] class StringColumnBuilder extends NativeColumnBuilder(new StringColumnStats, STRING) + +private[sql] class BinaryColumnBuilder extends ComplexColumnBuilder(new BinaryColumnStats, BINARY) private[sql] class FixedDecimalColumnBuilder( precision: Int, @@ -113,19 +117,15 @@ private[sql] class FixedDecimalColumnBuilder( new FixedDecimalColumnStats, FIXED_DECIMAL(precision, scale)) -private[sql] class StringColumnBuilder extends NativeColumnBuilder(new StringColumnStats, STRING) +// TODO (lian) Add support for array, struct and map +private[sql] class GenericColumnBuilder + extends ComplexColumnBuilder(new 
GenericColumnStats, GENERIC) private[sql] class DateColumnBuilder extends NativeColumnBuilder(new DateColumnStats, DATE) private[sql] class TimestampColumnBuilder extends NativeColumnBuilder(new TimestampColumnStats, TIMESTAMP) -private[sql] class BinaryColumnBuilder extends ComplexColumnBuilder(new BinaryColumnStats, BINARY) - -// TODO (lian) Add support for array, struct and map -private[sql] class GenericColumnBuilder - extends ComplexColumnBuilder(new GenericColumnStats, GENERIC) - private[sql] object ColumnBuilder { val DEFAULT_INITIAL_BUFFER_SIZE = 1024 * 1024 @@ -151,17 +151,17 @@ private[sql] object ColumnBuilder { columnName: String = "", useCompression: Boolean = false): ColumnBuilder = { val builder: ColumnBuilder = dataType match { + case BooleanType => new BooleanColumnBuilder + case ByteType => new ByteColumnBuilder + case ShortType => new ShortColumnBuilder case IntegerType => new IntColumnBuilder + case DateType => new DateColumnBuilder case LongType => new LongColumnBuilder + case TimestampType => new TimestampColumnBuilder case FloatType => new FloatColumnBuilder case DoubleType => new DoubleColumnBuilder - case BooleanType => new BooleanColumnBuilder - case ByteType => new ByteColumnBuilder - case ShortType => new ShortColumnBuilder case StringType => new StringColumnBuilder case BinaryType => new BinaryColumnBuilder - case DateType => new DateColumnBuilder - case TimestampType => new TimestampColumnBuilder case DecimalType.Fixed(precision, scale) if precision < 19 => new FixedDecimalColumnBuilder(precision, scale) case _ => new GenericColumnBuilder diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala index 1bce214d1d6c3..00374d1fa3ef1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnStats.scala @@ -132,17 +132,17 @@ private[sql] class ShortColumnStats extends ColumnStats { InternalRow(lower, upper, nullCount, count, sizeInBytes) } -private[sql] class LongColumnStats extends ColumnStats { - protected var upper = Long.MinValue - protected var lower = Long.MaxValue +private[sql] class IntColumnStats extends ColumnStats { + protected var upper = Int.MinValue + protected var lower = Int.MaxValue override def gatherStats(row: InternalRow, ordinal: Int): Unit = { super.gatherStats(row, ordinal) if (!row.isNullAt(ordinal)) { - val value = row.getLong(ordinal) + val value = row.getInt(ordinal) if (value > upper) upper = value if (value < lower) lower = value - sizeInBytes += LONG.defaultSize + sizeInBytes += INT.defaultSize } } @@ -150,17 +150,17 @@ private[sql] class LongColumnStats extends ColumnStats { InternalRow(lower, upper, nullCount, count, sizeInBytes) } -private[sql] class DoubleColumnStats extends ColumnStats { - protected var upper = Double.MinValue - protected var lower = Double.MaxValue +private[sql] class LongColumnStats extends ColumnStats { + protected var upper = Long.MinValue + protected var lower = Long.MaxValue override def gatherStats(row: InternalRow, ordinal: Int): Unit = { super.gatherStats(row, ordinal) if (!row.isNullAt(ordinal)) { - val value = row.getDouble(ordinal) + val value = row.getLong(ordinal) if (value > upper) upper = value if (value < lower) lower = value - sizeInBytes += DOUBLE.defaultSize + sizeInBytes += LONG.defaultSize } } @@ -186,35 +186,17 @@ private[sql] class FloatColumnStats extends ColumnStats { InternalRow(lower, upper, 
nullCount, count, sizeInBytes) } -private[sql] class FixedDecimalColumnStats extends ColumnStats { - protected var upper: Decimal = null - protected var lower: Decimal = null - - override def gatherStats(row: InternalRow, ordinal: Int): Unit = { - super.gatherStats(row, ordinal) - if (!row.isNullAt(ordinal)) { - val value = row(ordinal).asInstanceOf[Decimal] - if (upper == null || value.compareTo(upper) > 0) upper = value - if (lower == null || value.compareTo(lower) < 0) lower = value - sizeInBytes += FIXED_DECIMAL.defaultSize - } - } - - override def collectedStatistics: InternalRow = - InternalRow(lower, upper, nullCount, count, sizeInBytes) -} - -private[sql] class IntColumnStats extends ColumnStats { - protected var upper = Int.MinValue - protected var lower = Int.MaxValue +private[sql] class DoubleColumnStats extends ColumnStats { + protected var upper = Double.MinValue + protected var lower = Double.MaxValue override def gatherStats(row: InternalRow, ordinal: Int): Unit = { super.gatherStats(row, ordinal) if (!row.isNullAt(ordinal)) { - val value = row.getInt(ordinal) + val value = row.getDouble(ordinal) if (value > upper) upper = value if (value < lower) lower = value - sizeInBytes += INT.defaultSize + sizeInBytes += DOUBLE.defaultSize } } @@ -240,10 +222,6 @@ private[sql] class StringColumnStats extends ColumnStats { InternalRow(lower, upper, nullCount, count, sizeInBytes) } -private[sql] class DateColumnStats extends IntColumnStats - -private[sql] class TimestampColumnStats extends LongColumnStats - private[sql] class BinaryColumnStats extends ColumnStats { override def gatherStats(row: InternalRow, ordinal: Int): Unit = { super.gatherStats(row, ordinal) @@ -256,6 +234,24 @@ private[sql] class BinaryColumnStats extends ColumnStats { InternalRow(null, null, nullCount, count, sizeInBytes) } +private[sql] class FixedDecimalColumnStats extends ColumnStats { + protected var upper: Decimal = null + protected var lower: Decimal = null + + override def gatherStats(row: InternalRow, ordinal: Int): Unit = { + super.gatherStats(row, ordinal) + if (!row.isNullAt(ordinal)) { + val value = row(ordinal).asInstanceOf[Decimal] + if (upper == null || value.compareTo(upper) > 0) upper = value + if (lower == null || value.compareTo(lower) < 0) lower = value + sizeInBytes += FIXED_DECIMAL.defaultSize + } + } + + override def collectedStatistics: InternalRow = + InternalRow(lower, upper, nullCount, count, sizeInBytes) +} + private[sql] class GenericColumnStats extends ColumnStats { override def gatherStats(row: InternalRow, ordinal: Int): Unit = { super.gatherStats(row, ordinal) @@ -267,3 +263,7 @@ private[sql] class GenericColumnStats extends ColumnStats { override def collectedStatistics: InternalRow = InternalRow(null, null, nullCount, count, sizeInBytes) } + +private[sql] class DateColumnStats extends IntColumnStats + +private[sql] class TimestampColumnStats extends LongColumnStats diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala index 8bf2151e4de68..fc72360c88fe1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/ColumnType.scala @@ -447,17 +447,17 @@ private[sql] object GENERIC extends ByteArrayColumnType[DataType](12, 16) { private[sql] object ColumnType { def apply(dataType: DataType): ColumnType[_, _] = { dataType match { + case BooleanType => BOOLEAN + case ByteType => BYTE + case ShortType => 
SHORT case IntegerType => INT + case DateType => DATE case LongType => LONG + case TimestampType => TIMESTAMP case FloatType => FLOAT case DoubleType => DOUBLE - case BooleanType => BOOLEAN - case ByteType => BYTE - case ShortType => SHORT case StringType => STRING case BinaryType => BINARY - case DateType => DATE - case TimestampType => TIMESTAMP case DecimalType.Fixed(precision, scale) if precision < 19 => FIXED_DECIMAL(precision, scale) case _ => GENERIC diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala index 74a22353b1d27..056d435eecd23 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala @@ -237,7 +237,7 @@ private[sql] object SparkSqlSerializer2 { out.writeShort(row.getShort(i)) } - case IntegerType => + case IntegerType | DateType => if (row.isNullAt(i)) { out.writeByte(NULL) } else { @@ -245,7 +245,7 @@ private[sql] object SparkSqlSerializer2 { out.writeInt(row.getInt(i)) } - case LongType => + case LongType | TimestampType => if (row.isNullAt(i)) { out.writeByte(NULL) } else { @@ -269,55 +269,39 @@ private[sql] object SparkSqlSerializer2 { out.writeDouble(row.getDouble(i)) } - case decimal: DecimalType => + case StringType => if (row.isNullAt(i)) { out.writeByte(NULL) } else { out.writeByte(NOT_NULL) - val value = row.apply(i).asInstanceOf[Decimal] - val javaBigDecimal = value.toJavaBigDecimal - // First, write out the unscaled value. - val bytes: Array[Byte] = javaBigDecimal.unscaledValue().toByteArray + val bytes = row.getAs[UTF8String](i).getBytes out.writeInt(bytes.length) out.write(bytes) - // Then, write out the scale. - out.writeInt(javaBigDecimal.scale()) } - case DateType => - if (row.isNullAt(i)) { - out.writeByte(NULL) - } else { - out.writeByte(NOT_NULL) - out.writeInt(row.getAs[Int](i)) - } - - case TimestampType => - if (row.isNullAt(i)) { - out.writeByte(NULL) - } else { - out.writeByte(NOT_NULL) - out.writeLong(row.getAs[Long](i)) - } - - case StringType => + case BinaryType => if (row.isNullAt(i)) { out.writeByte(NULL) } else { out.writeByte(NOT_NULL) - val bytes = row.getAs[UTF8String](i).getBytes + val bytes = row.getAs[Array[Byte]](i) out.writeInt(bytes.length) out.write(bytes) } - case BinaryType => + case decimal: DecimalType => if (row.isNullAt(i)) { out.writeByte(NULL) } else { out.writeByte(NOT_NULL) - val bytes = row.getAs[Array[Byte]](i) + val value = row.apply(i).asInstanceOf[Decimal] + val javaBigDecimal = value.toJavaBigDecimal + // First, write out the unscaled value. + val bytes: Array[Byte] = javaBigDecimal.unscaledValue().toByteArray out.writeInt(bytes.length) out.write(bytes) + // Then, write out the scale. + out.writeInt(javaBigDecimal.scale()) } } i += 1 @@ -364,14 +348,14 @@ private[sql] object SparkSqlSerializer2 { mutableRow.setShort(i, in.readShort()) } - case IntegerType => + case IntegerType | DateType => if (in.readByte() == NULL) { mutableRow.setNullAt(i) } else { mutableRow.setInt(i, in.readInt()) } - case LongType => + case LongType | TimestampType => if (in.readByte() == NULL) { mutableRow.setNullAt(i) } else { @@ -392,53 +376,39 @@ private[sql] object SparkSqlSerializer2 { mutableRow.setDouble(i, in.readDouble()) } - case decimal: DecimalType => + case StringType => if (in.readByte() == NULL) { mutableRow.setNullAt(i) } else { - // First, read in the unscaled value. 
val length = in.readInt() val bytes = new Array[Byte](length) in.readFully(bytes) - val unscaledVal = new BigInteger(bytes) - // Then, read the scale. - val scale = in.readInt() - // Finally, create the Decimal object and set it in the row. - mutableRow.update(i, Decimal(new BigDecimal(unscaledVal, scale))) - } - - case DateType => - if (in.readByte() == NULL) { - mutableRow.setNullAt(i) - } else { - mutableRow.update(i, in.readInt()) - } - - case TimestampType => - if (in.readByte() == NULL) { - mutableRow.setNullAt(i) - } else { - mutableRow.update(i, in.readLong()) + mutableRow.update(i, UTF8String.fromBytes(bytes)) } - case StringType => + case BinaryType => if (in.readByte() == NULL) { mutableRow.setNullAt(i) } else { val length = in.readInt() val bytes = new Array[Byte](length) in.readFully(bytes) - mutableRow.update(i, UTF8String.fromBytes(bytes)) + mutableRow.update(i, bytes) } - case BinaryType => + case decimal: DecimalType => if (in.readByte() == NULL) { mutableRow.setNullAt(i) } else { + // First, read in the unscaled value. val length = in.readInt() val bytes = new Array[Byte](length) in.readFully(bytes) - mutableRow.update(i, bytes) + val unscaledVal = new BigInteger(bytes) + // Then, read the scale. + val scale = in.readInt() + // Finally, create the Decimal object and set it in the row. + mutableRow.update(i, Decimal(new BigDecimal(unscaledVal, scale))) } } i += 1 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala index 0d96a1e8070b1..df2a96dfeb619 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala @@ -198,19 +198,18 @@ private[parquet] class RowWriteSupport extends WriteSupport[InternalRow] with Lo private[parquet] def writePrimitive(schema: DataType, value: Any): Unit = { if (value != null) { schema match { - case StringType => writer.addBinary( - Binary.fromByteArray(value.asInstanceOf[UTF8String].getBytes)) - case BinaryType => writer.addBinary( - Binary.fromByteArray(value.asInstanceOf[Array[Byte]])) - case IntegerType => writer.addInteger(value.asInstanceOf[Int]) + case BooleanType => writer.addBoolean(value.asInstanceOf[Boolean]) + case ByteType => writer.addInteger(value.asInstanceOf[Byte]) case ShortType => writer.addInteger(value.asInstanceOf[Short]) + case IntegerType | DateType => writer.addInteger(value.asInstanceOf[Int]) case LongType => writer.addLong(value.asInstanceOf[Long]) case TimestampType => writeTimestamp(value.asInstanceOf[Long]) - case ByteType => writer.addInteger(value.asInstanceOf[Byte]) - case DoubleType => writer.addDouble(value.asInstanceOf[Double]) case FloatType => writer.addFloat(value.asInstanceOf[Float]) - case BooleanType => writer.addBoolean(value.asInstanceOf[Boolean]) - case DateType => writer.addInteger(value.asInstanceOf[Int]) + case DoubleType => writer.addDouble(value.asInstanceOf[Double]) + case StringType => writer.addBinary( + Binary.fromByteArray(value.asInstanceOf[UTF8String].getBytes)) + case BinaryType => writer.addBinary( + Binary.fromByteArray(value.asInstanceOf[Array[Byte]])) case d: DecimalType => if (d.precisionInfo == None || d.precisionInfo.get.precision > 18) { sys.error(s"Unsupported datatype $d, cannot write to consumer") @@ -353,19 +352,18 @@ private[parquet] class MutableRowWriteSupport extends RowWriteSupport { record: InternalRow, index: Int): Unit = { ctype match { 
+ case BooleanType => writer.addBoolean(record.getBoolean(index)) + case ByteType => writer.addInteger(record.getByte(index)) + case ShortType => writer.addInteger(record.getShort(index)) + case IntegerType | DateType => writer.addInteger(record.getInt(index)) + case LongType => writer.addLong(record.getLong(index)) + case TimestampType => writeTimestamp(record.getLong(index)) + case FloatType => writer.addFloat(record.getFloat(index)) + case DoubleType => writer.addDouble(record.getDouble(index)) case StringType => writer.addBinary( Binary.fromByteArray(record(index).asInstanceOf[UTF8String].getBytes)) case BinaryType => writer.addBinary( Binary.fromByteArray(record(index).asInstanceOf[Array[Byte]])) - case IntegerType => writer.addInteger(record.getInt(index)) - case ShortType => writer.addInteger(record.getShort(index)) - case LongType => writer.addLong(record.getLong(index)) - case ByteType => writer.addInteger(record.getByte(index)) - case DoubleType => writer.addDouble(record.getDouble(index)) - case FloatType => writer.addFloat(record.getFloat(index)) - case BooleanType => writer.addBoolean(record.getBoolean(index)) - case DateType => writer.addInteger(record.getInt(index)) - case TimestampType => writeTimestamp(record.getLong(index)) case d: DecimalType => if (d.precisionInfo == None || d.precisionInfo.get.precision > 18) { sys.error(s"Unsupported datatype $d, cannot write to consumer") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala index 4d5199a140344..e748bd7857bd8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTypes.scala @@ -38,8 +38,8 @@ import org.apache.spark.sql.types._ private[parquet] object ParquetTypesConverter extends Logging { def isPrimitiveType(ctype: DataType): Boolean = ctype match { - case _: NumericType | BooleanType | StringType | BinaryType => true - case _: DataType => false + case _: NumericType | BooleanType | DateType | TimestampType | StringType | BinaryType => true + case _ => false } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala index 1f37455dd0bc4..9bd7b221e93f8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnStatsSuite.scala @@ -22,19 +22,20 @@ import org.apache.spark.sql.catalyst.expressions.InternalRow import org.apache.spark.sql.types._ class ColumnStatsSuite extends SparkFunSuite { + testColumnStats(classOf[BooleanColumnStats], BOOLEAN, InternalRow(true, false, 0)) testColumnStats(classOf[ByteColumnStats], BYTE, InternalRow(Byte.MaxValue, Byte.MinValue, 0)) testColumnStats(classOf[ShortColumnStats], SHORT, InternalRow(Short.MaxValue, Short.MinValue, 0)) testColumnStats(classOf[IntColumnStats], INT, InternalRow(Int.MaxValue, Int.MinValue, 0)) + testColumnStats(classOf[DateColumnStats], DATE, InternalRow(Int.MaxValue, Int.MinValue, 0)) testColumnStats(classOf[LongColumnStats], LONG, InternalRow(Long.MaxValue, Long.MinValue, 0)) + testColumnStats(classOf[TimestampColumnStats], TIMESTAMP, + InternalRow(Long.MaxValue, Long.MinValue, 0)) testColumnStats(classOf[FloatColumnStats], FLOAT, InternalRow(Float.MaxValue, Float.MinValue, 0)) testColumnStats(classOf[DoubleColumnStats], DOUBLE, 
InternalRow(Double.MaxValue, Double.MinValue, 0)) + testColumnStats(classOf[StringColumnStats], STRING, InternalRow(null, null, 0)) testColumnStats(classOf[FixedDecimalColumnStats], FIXED_DECIMAL(15, 10), InternalRow(null, null, 0)) - testColumnStats(classOf[StringColumnStats], STRING, InternalRow(null, null, 0)) - testColumnStats(classOf[DateColumnStats], DATE, InternalRow(Int.MaxValue, Int.MinValue, 0)) - testColumnStats(classOf[TimestampColumnStats], TIMESTAMP, - InternalRow(Long.MaxValue, Long.MinValue, 0)) def testColumnStats[T <: AtomicType, U <: ColumnStats]( columnStatsClass: Class[U], diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala index 6daddfb2c4804..4d46a657056e0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnTypeSuite.scala @@ -36,9 +36,9 @@ class ColumnTypeSuite extends SparkFunSuite with Logging { test("defaultSize") { val checks = Map( - INT -> 4, SHORT -> 2, LONG -> 8, BYTE -> 1, DOUBLE -> 8, FLOAT -> 4, - FIXED_DECIMAL(15, 10) -> 8, BOOLEAN -> 1, STRING -> 8, DATE -> 4, TIMESTAMP -> 8, - BINARY -> 16, GENERIC -> 16) + BOOLEAN -> 1, BYTE -> 1, SHORT -> 2, INT -> 4, DATE -> 4, + LONG -> 8, TIMESTAMP -> 8, FLOAT -> 4, DOUBLE -> 8, + STRING -> 8, BINARY -> 16, FIXED_DECIMAL(15, 10) -> 8, GENERIC -> 16) checks.foreach { case (columnType, expectedSize) => assertResult(expectedSize, s"Wrong defaultSize for $columnType") { @@ -60,27 +60,24 @@ class ColumnTypeSuite extends SparkFunSuite with Logging { } } - checkActualSize(INT, Int.MaxValue, 4) + checkActualSize(BOOLEAN, true, 1) + checkActualSize(BYTE, Byte.MaxValue, 1) checkActualSize(SHORT, Short.MaxValue, 2) + checkActualSize(INT, Int.MaxValue, 4) + checkActualSize(DATE, Int.MaxValue, 4) checkActualSize(LONG, Long.MaxValue, 8) - checkActualSize(BYTE, Byte.MaxValue, 1) - checkActualSize(DOUBLE, Double.MaxValue, 8) + checkActualSize(TIMESTAMP, Long.MaxValue, 8) checkActualSize(FLOAT, Float.MaxValue, 4) - checkActualSize(FIXED_DECIMAL(15, 10), Decimal(0, 15, 10), 8) - checkActualSize(BOOLEAN, true, 1) + checkActualSize(DOUBLE, Double.MaxValue, 8) checkActualSize(STRING, UTF8String.fromString("hello"), 4 + "hello".getBytes("utf-8").length) - checkActualSize(DATE, 0, 4) - checkActualSize(TIMESTAMP, 0L, 8) - - val binary = Array.fill[Byte](4)(0: Byte) - checkActualSize(BINARY, binary, 4 + 4) + checkActualSize(BINARY, Array.fill[Byte](4)(0.toByte), 4 + 4) + checkActualSize(FIXED_DECIMAL(15, 10), Decimal(0, 15, 10), 8) val generic = Map(1 -> "a") checkActualSize(GENERIC, SparkSqlSerializer.serialize(generic), 4 + 8) } - testNativeColumnType[BooleanType.type]( - BOOLEAN, + testNativeColumnType(BOOLEAN)( (buffer: ByteBuffer, v: Boolean) => { buffer.put((if (v) 1 else 0).toByte) }, @@ -88,18 +85,23 @@ class ColumnTypeSuite extends SparkFunSuite with Logging { buffer.get() == 1 }) - testNativeColumnType[IntegerType.type](INT, _.putInt(_), _.getInt) + testNativeColumnType(BYTE)(_.put(_), _.get) + + testNativeColumnType(SHORT)(_.putShort(_), _.getShort) + + testNativeColumnType(INT)(_.putInt(_), _.getInt) + + testNativeColumnType(DATE)(_.putInt(_), _.getInt) - testNativeColumnType[ShortType.type](SHORT, _.putShort(_), _.getShort) + testNativeColumnType(LONG)(_.putLong(_), _.getLong) - testNativeColumnType[LongType.type](LONG, _.putLong(_), _.getLong) + testNativeColumnType(TIMESTAMP)(_.putLong(_), _.getLong) - 
testNativeColumnType[ByteType.type](BYTE, _.put(_), _.get) + testNativeColumnType(FLOAT)(_.putFloat(_), _.getFloat) - testNativeColumnType[DoubleType.type](DOUBLE, _.putDouble(_), _.getDouble) + testNativeColumnType(DOUBLE)(_.putDouble(_), _.getDouble) - testNativeColumnType[DecimalType]( - FIXED_DECIMAL(15, 10), + testNativeColumnType(FIXED_DECIMAL(15, 10))( (buffer: ByteBuffer, decimal: Decimal) => { buffer.putLong(decimal.toUnscaledLong) }, @@ -107,10 +109,8 @@ class ColumnTypeSuite extends SparkFunSuite with Logging { Decimal(buffer.getLong(), 15, 10) }) - testNativeColumnType[FloatType.type](FLOAT, _.putFloat(_), _.getFloat) - testNativeColumnType[StringType.type]( - STRING, + testNativeColumnType(STRING)( (buffer: ByteBuffer, string: UTF8String) => { val bytes = string.getBytes buffer.putInt(bytes.length) @@ -197,8 +197,8 @@ class ColumnTypeSuite extends SparkFunSuite with Logging { } def testNativeColumnType[T <: AtomicType]( - columnType: NativeColumnType[T], - putter: (ByteBuffer, T#InternalType) => Unit, + columnType: NativeColumnType[T]) + (putter: (ByteBuffer, T#InternalType) => Unit, getter: (ByteBuffer) => T#InternalType): Unit = { testColumnType[T, T#InternalType](columnType, putter, getter) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala index 7c86eae3f77fd..d9861339739c9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/ColumnarTestUtils.scala @@ -39,18 +39,18 @@ object ColumnarTestUtils { } (columnType match { + case BOOLEAN => Random.nextBoolean() case BYTE => (Random.nextInt(Byte.MaxValue * 2) - Byte.MaxValue).toByte case SHORT => (Random.nextInt(Short.MaxValue * 2) - Short.MaxValue).toShort case INT => Random.nextInt() + case DATE => Random.nextInt() case LONG => Random.nextLong() + case TIMESTAMP => Random.nextLong() case FLOAT => Random.nextFloat() case DOUBLE => Random.nextDouble() - case FIXED_DECIMAL(precision, scale) => Decimal(Random.nextLong() % 100, precision, scale) case STRING => UTF8String.fromString(Random.nextString(Random.nextInt(32))) - case BOOLEAN => Random.nextBoolean() case BINARY => randomBytes(Random.nextInt(32)) - case DATE => Random.nextInt() - case TIMESTAMP => Random.nextLong() + case FIXED_DECIMAL(precision, scale) => Decimal(Random.nextLong() % 100, precision, scale) case _ => // Using a random one-element map instead of an arbitrary object Map(Random.nextInt() -> Random.nextString(Random.nextInt(32))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala index 2a6e0c376551a..9eaa769846088 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnAccessorSuite.scala @@ -42,9 +42,9 @@ class NullableColumnAccessorSuite extends SparkFunSuite { import ColumnarTestUtils._ Seq( - INT, LONG, SHORT, BOOLEAN, BYTE, STRING, DOUBLE, FLOAT, FIXED_DECIMAL(15, 10), BINARY, GENERIC, - DATE, TIMESTAMP - ).foreach { + BOOLEAN, BYTE, SHORT, INT, DATE, LONG, TIMESTAMP, FLOAT, DOUBLE, + STRING, BINARY, FIXED_DECIMAL(15, 10), GENERIC) + .foreach { testNullableColumnAccessor(_) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala index cb4e9f1eb7f46..17e9ae464bcc0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/NullableColumnBuilderSuite.scala @@ -38,9 +38,9 @@ class NullableColumnBuilderSuite extends SparkFunSuite { import ColumnarTestUtils._ Seq( - INT, LONG, SHORT, BOOLEAN, BYTE, STRING, DOUBLE, FLOAT, FIXED_DECIMAL(15, 10), BINARY, GENERIC, - DATE, TIMESTAMP - ).foreach { + BOOLEAN, BYTE, SHORT, INT, DATE, LONG, TIMESTAMP, FLOAT, DOUBLE, + STRING, BINARY, FIXED_DECIMAL(15, 10), GENERIC) + .foreach { testNullableColumnBuilder(_) } From 3664ee25f0a67de5ba76e9487a55a55216ae589f Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Mon, 29 Jun 2015 11:53:17 -0700 Subject: [PATCH 046/122] [SPARK-8066, SPARK-8067] [hive] Add support for Hive 1.0, 1.1 and 1.2. Allow HiveContext to connect to metastores of those versions; some new shims had to be added to account for changing internal APIs. A new test was added to exercise the "reset()" path which now also requires a shim; and the test code was changed to use a directory under the build's target to store ivy dependencies. Without that, at least I consistently run into issues with Ivy messing up (or being confused) by my existing caches. Author: Marcelo Vanzin Closes #7026 from vanzin/SPARK-8067 and squashes the following commits: 3e2e67b [Marcelo Vanzin] [SPARK-8066, SPARK-8067] [hive] Add support for Hive 1.0, 1.1 and 1.2. --- .../spark/sql/hive/client/ClientWrapper.scala | 5 +- .../spark/sql/hive/client/HiveShim.scala | 70 ++++++++++++++++++- .../hive/client/IsolatedClientLoader.scala | 13 ++-- .../spark/sql/hive/client/package.scala | 33 +++++++-- .../spark/sql/hive/client/VersionsSuite.scala | 25 +++++-- 5 files changed, 131 insertions(+), 15 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala index 2f771d76793e5..4c708cec572ae 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala @@ -97,6 +97,9 @@ private[hive] class ClientWrapper( case hive.v12 => new Shim_v0_12() case hive.v13 => new Shim_v0_13() case hive.v14 => new Shim_v0_14() + case hive.v1_0 => new Shim_v1_0() + case hive.v1_1 => new Shim_v1_1() + case hive.v1_2 => new Shim_v1_2() } // Create an internal session state for this ClientWrapper. 
@@ -456,7 +459,7 @@ private[hive] class ClientWrapper( logDebug(s"Deleting table $t") val table = client.getTable("default", t) client.getIndexes("default", t, 255).foreach { index => - client.dropIndex("default", t, index.getIndexName, true) + shim.dropIndex(client, "default", t, index.getIndexName) } if (!table.isIndexTable) { client.dropTable("default", t) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala index e7c1779f80ce6..1fa9d278e2a57 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.hive.client -import java.lang.{Boolean => JBoolean, Integer => JInteger} +import java.lang.{Boolean => JBoolean, Integer => JInteger, Long => JLong} import java.lang.reflect.{Method, Modifier} import java.net.URI import java.util.{ArrayList => JArrayList, List => JList, Map => JMap, Set => JSet} @@ -94,6 +94,8 @@ private[client] sealed abstract class Shim { holdDDLTime: Boolean, listBucketingEnabled: Boolean): Unit + def dropIndex(hive: Hive, dbName: String, tableName: String, indexName: String): Unit + protected def findStaticMethod(klass: Class[_], name: String, args: Class[_]*): Method = { val method = findMethod(klass, name, args: _*) require(Modifier.isStatic(method.getModifiers()), @@ -166,6 +168,14 @@ private[client] class Shim_v0_12 extends Shim { JInteger.TYPE, JBoolean.TYPE, JBoolean.TYPE) + private lazy val dropIndexMethod = + findMethod( + classOf[Hive], + "dropIndex", + classOf[String], + classOf[String], + classOf[String], + JBoolean.TYPE) override def setCurrentSessionState(state: SessionState): Unit = { // Starting from Hive 0.13, setCurrentSessionState will internally override @@ -234,6 +244,10 @@ private[client] class Shim_v0_12 extends Shim { numDP: JInteger, holdDDLTime: JBoolean, listBucketingEnabled: JBoolean) } + override def dropIndex(hive: Hive, dbName: String, tableName: String, indexName: String): Unit = { + dropIndexMethod.invoke(hive, dbName, tableName, indexName, true: JBoolean) + } + } private[client] class Shim_v0_13 extends Shim_v0_12 { @@ -379,3 +393,57 @@ private[client] class Shim_v0_14 extends Shim_v0_13 { TimeUnit.MILLISECONDS).asInstanceOf[Long] } } + +private[client] class Shim_v1_0 extends Shim_v0_14 { + +} + +private[client] class Shim_v1_1 extends Shim_v1_0 { + + private lazy val dropIndexMethod = + findMethod( + classOf[Hive], + "dropIndex", + classOf[String], + classOf[String], + classOf[String], + JBoolean.TYPE, + JBoolean.TYPE) + + override def dropIndex(hive: Hive, dbName: String, tableName: String, indexName: String): Unit = { + dropIndexMethod.invoke(hive, dbName, tableName, indexName, true: JBoolean, true: JBoolean) + } + +} + +private[client] class Shim_v1_2 extends Shim_v1_1 { + + private lazy val loadDynamicPartitionsMethod = + findMethod( + classOf[Hive], + "loadDynamicPartitions", + classOf[Path], + classOf[String], + classOf[JMap[String, String]], + JBoolean.TYPE, + JInteger.TYPE, + JBoolean.TYPE, + JBoolean.TYPE, + JBoolean.TYPE, + JLong.TYPE) + + override def loadDynamicPartitions( + hive: Hive, + loadPath: Path, + tableName: String, + partSpec: JMap[String, String], + replace: Boolean, + numDP: Int, + holdDDLTime: Boolean, + listBucketingEnabled: Boolean): Unit = { + loadDynamicPartitionsMethod.invoke(hive, loadPath, tableName, partSpec, replace: JBoolean, + numDP: JInteger, 
holdDDLTime: JBoolean, listBucketingEnabled: JBoolean, JBoolean.FALSE, + 0: JLong) + } + +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala index 0934ad5034671..3d609a66f3664 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala @@ -41,9 +41,11 @@ private[hive] object IsolatedClientLoader { */ def forVersion( version: String, - config: Map[String, String] = Map.empty): IsolatedClientLoader = synchronized { + config: Map[String, String] = Map.empty, + ivyPath: Option[String] = None): IsolatedClientLoader = synchronized { val resolvedVersion = hiveVersion(version) - val files = resolvedVersions.getOrElseUpdate(resolvedVersion, downloadVersion(resolvedVersion)) + val files = resolvedVersions.getOrElseUpdate(resolvedVersion, + downloadVersion(resolvedVersion, ivyPath)) new IsolatedClientLoader(hiveVersion(version), files, config) } @@ -51,9 +53,12 @@ private[hive] object IsolatedClientLoader { case "12" | "0.12" | "0.12.0" => hive.v12 case "13" | "0.13" | "0.13.0" | "0.13.1" => hive.v13 case "14" | "0.14" | "0.14.0" => hive.v14 + case "1.0" | "1.0.0" => hive.v1_0 + case "1.1" | "1.1.0" => hive.v1_1 + case "1.2" | "1.2.0" => hive.v1_2 } - private def downloadVersion(version: HiveVersion): Seq[URL] = { + private def downloadVersion(version: HiveVersion, ivyPath: Option[String]): Seq[URL] = { val hiveArtifacts = version.extraDeps ++ Seq("hive-metastore", "hive-exec", "hive-common", "hive-serde") .map(a => s"org.apache.hive:$a:${version.fullVersion}") ++ @@ -64,7 +69,7 @@ private[hive] object IsolatedClientLoader { SparkSubmitUtils.resolveMavenCoordinates( hiveArtifacts.mkString(","), Some("http://www.datanucleus.org/downloads/maven2"), - None, + ivyPath, exclusions = version.exclusions) } val allFiles = classpath.split(",").map(new File(_)).toSet diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala index 27a3d8f5896cc..b48082fe4b363 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala @@ -32,13 +32,36 @@ package object client { // Hive 0.14 depends on calcite 0.9.2-incubating-SNAPSHOT which does not exist in // maven central anymore, so override those with a version that exists. // - // org.pentaho:pentaho-aggdesigner-algorithm is also nowhere to be found, so exclude - // it explicitly. If it's needed by the metastore client, users will have to dig it - // out of somewhere and use configuration to point Spark at the correct jars. + // The other excluded dependencies are also nowhere to be found, so exclude them explicitly. If + // they're needed by the metastore client, users will have to dig them out of somewhere and use + // configuration to point Spark at the correct jars. 
case object v14 extends HiveVersion("0.14.0", - Seq("org.apache.calcite:calcite-core:1.3.0-incubating", + extraDeps = Seq("org.apache.calcite:calcite-core:1.3.0-incubating", "org.apache.calcite:calcite-avatica:1.3.0-incubating"), - Seq("org.pentaho:pentaho-aggdesigner-algorithm")) + exclusions = Seq("org.pentaho:pentaho-aggdesigner-algorithm")) + + case object v1_0 extends HiveVersion("1.0.0", + exclusions = Seq("eigenbase:eigenbase-properties", + "org.pentaho:pentaho-aggdesigner-algorithm", + "net.hydromatic:linq4j", + "net.hydromatic:quidem")) + + // The curator dependency was added to the exclusions here because it seems to confuse the ivy + // library. org.apache.curator:curator is a pom dependency but ivy tries to find the jar for it, + // and fails. + case object v1_1 extends HiveVersion("1.1.0", + exclusions = Seq("eigenbase:eigenbase-properties", + "org.apache.curator:*", + "org.pentaho:pentaho-aggdesigner-algorithm", + "net.hydromatic:linq4j", + "net.hydromatic:quidem")) + + case object v1_2 extends HiveVersion("1.2.0", + exclusions = Seq("eigenbase:eigenbase-properties", + "org.apache.curator:*", + "org.pentaho:pentaho-aggdesigner-algorithm", + "net.hydromatic:linq4j", + "net.hydromatic:quidem")) } // scalastyle:on diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala index 9a571650b6e25..d52e162acbd04 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.hive.client +import java.io.File + import org.apache.spark.{Logging, SparkFunSuite} import org.apache.spark.sql.catalyst.util.quietly import org.apache.spark.util.Utils @@ -28,6 +30,12 @@ import org.apache.spark.util.Utils * is not fully tested. */ class VersionsSuite extends SparkFunSuite with Logging { + + // Do not use a temp path here to speed up subsequent executions of the unit test during + // development. + private val ivyPath = Some( + new File(sys.props("java.io.tmpdir"), "hive-ivy-cache").getAbsolutePath()) + private def buildConf() = { lazy val warehousePath = Utils.createTempDir() lazy val metastorePath = Utils.createTempDir() @@ -38,7 +46,7 @@ class VersionsSuite extends SparkFunSuite with Logging { } test("success sanity check") { - val badClient = IsolatedClientLoader.forVersion("13", buildConf()).client + val badClient = IsolatedClientLoader.forVersion("13", buildConf(), ivyPath).client val db = new HiveDatabase("default", "") badClient.createDatabase(db) } @@ -67,19 +75,21 @@ class VersionsSuite extends SparkFunSuite with Logging { // TODO: currently only works on mysql where we manually create the schema... 
ignore("failure sanity check") { val e = intercept[Throwable] { - val badClient = quietly { IsolatedClientLoader.forVersion("13", buildConf()).client } + val badClient = quietly { + IsolatedClientLoader.forVersion("13", buildConf(), ivyPath).client + } } assert(getNestedMessages(e) contains "Unknown column 'A0.OWNER_NAME' in 'field list'") } - private val versions = Seq("12", "13", "14") + private val versions = Seq("12", "13", "14", "1.0.0", "1.1.0", "1.2.0") private var client: ClientInterface = null versions.foreach { version => test(s"$version: create client") { client = null - client = IsolatedClientLoader.forVersion(version, buildConf()).client + client = IsolatedClientLoader.forVersion(version, buildConf(), ivyPath).client } test(s"$version: createDatabase") { @@ -170,5 +180,12 @@ class VersionsSuite extends SparkFunSuite with Logging { false, false) } + + test(s"$version: create index and reset") { + client.runSqlHive("CREATE TABLE indexed_table (key INT)") + client.runSqlHive("CREATE INDEX index_1 ON TABLE indexed_table(key) " + + "as 'COMPACT' WITH DEFERRED REBUILD") + client.reset() + } } } From a5c2961caaafd751f11bdd406bb6885443d7572e Mon Sep 17 00:00:00 2001 From: Tarek Auel Date: Mon, 29 Jun 2015 11:57:19 -0700 Subject: [PATCH 047/122] [SPARK-8235] [SQL] misc function sha / sha1 Jira: https://issues.apache.org/jira/browse/SPARK-8235 I added the support for sha1. If I understood rxin correctly, sha and sha1 should execute the same algorithm, shouldn't they? Please take a close look on the Python part. This is adopted from #6934 Author: Tarek Auel Author: Tarek Auel Closes #6963 from tarekauel/SPARK-8235 and squashes the following commits: f064563 [Tarek Auel] change to shaHex 7ce3cdc [Tarek Auel] rely on automatic cast a1251d6 [Tarek Auel] Merge remote-tracking branch 'upstream/master' into SPARK-8235 68eb043 [Tarek Auel] added docstring be5aff1 [Tarek Auel] improved error message 7336c96 [Tarek Auel] added type check cf23a80 [Tarek Auel] simplified example ebf75ef [Tarek Auel] [SPARK-8301] updated the python documentation. Removed sha in python and scala 6d6ff0d [Tarek Auel] [SPARK-8233] added docstring ea191a9 [Tarek Auel] [SPARK-8233] fixed signatureof python function. Added expected type to misc e3fd7c3 [Tarek Auel] SPARK[8235] added sha to the list of __all__ e5dad4e [Tarek Auel] SPARK[8235] sha / sha1 --- python/pyspark/sql/functions.py | 14 +++++++++ .../catalyst/analysis/FunctionRegistry.scala | 2 ++ .../spark/sql/catalyst/expressions/misc.scala | 30 ++++++++++++++++++- .../expressions/MiscFunctionsSuite.scala | 8 +++++ .../org/apache/spark/sql/functions.scala | 16 ++++++++++ .../spark/sql/DataFrameFunctionsSuite.scala | 12 ++++++++ 6 files changed, 81 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 7d3d0361610b7..45ecd826bd3bd 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -42,6 +42,7 @@ 'monotonicallyIncreasingId', 'rand', 'randn', + 'sha1', 'sha2', 'sparkPartitionId', 'struct', @@ -382,6 +383,19 @@ def sha2(col, numBits): return Column(jc) +@ignore_unicode_prefix +@since(1.5) +def sha1(col): + """Returns the hex string result of SHA-1. 
+ + >>> sqlContext.createDataFrame([('ABC',)], ['a']).select(sha1('a').alias('hash')).collect() + [Row(hash=u'3c01bdbb26f358bab27f267924aa2c9a03fcfdb8')] + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.sha1(_to_java_column(col)) + return Column(jc) + + @since(1.4) def sparkPartitionId(): """A column for partition ID of the Spark task. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 457948a800a17..b24064d061533 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -136,6 +136,8 @@ object FunctionRegistry { // misc functions expression[Md5]("md5"), expression[Sha2]("sha2"), + expression[Sha1]("sha1"), + expression[Sha1]("sha"), // aggregate functions expression[Average]("avg"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index e80706fc65aff..9a39165a1ff05 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -21,8 +21,9 @@ import java.security.MessageDigest import java.security.NoSuchAlgorithmException import org.apache.commons.codec.digest.DigestUtils +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.types.{BinaryType, IntegerType, StringType, DataType} +import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String /** @@ -140,3 +141,30 @@ case class Sha2(left: Expression, right: Expression) """ } } + +/** + * A function that calculates a sha1 hash value and returns it as a hex string + * For input of type [[BinaryType]] or [[StringType]] + */ +case class Sha1(child: Expression) extends UnaryExpression with ExpectsInputTypes { + + override def dataType: DataType = StringType + + override def expectedChildTypes: Seq[DataType] = Seq(BinaryType) + + override def eval(input: InternalRow): Any = { + val value = child.eval(input) + if (value == null) { + null + } else { + UTF8String.fromString(DigestUtils.shaHex(value.asInstanceOf[Array[Byte]])) + } + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + defineCodeGen(ctx, ev, c => + "org.apache.spark.unsafe.types.UTF8String.fromString" + + s"(org.apache.commons.codec.digest.DigestUtils.shaHex($c))" + ) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala index 38482c54c61db..36e636b5da6b8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala @@ -31,6 +31,14 @@ class MiscFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Md5(Literal.create(null, BinaryType)), null) } + test("sha1") { + checkEvaluation(Sha1(Literal("ABC".getBytes)), "3c01bdbb26f358bab27f267924aa2c9a03fcfdb8") + checkEvaluation(Sha1(Literal.create(Array[Byte](1, 2, 3, 4, 5, 6), BinaryType)), + 
"5d211bad8f4ee70e16c7d343a838fc344a1ed961") + checkEvaluation(Sha1(Literal.create(null, BinaryType)), null) + checkEvaluation(Sha1(Literal("".getBytes)), "da39a3ee5e6b4b0d3255bfef95601890afd80709") + } + test("sha2") { checkEvaluation(Sha2(Literal("ABC".getBytes), Literal(256)), DigestUtils.sha256Hex("ABC")) checkEvaluation(Sha2(Literal.create(Array[Byte](1, 2, 3, 4, 5, 6), BinaryType), Literal(384)), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 355ce0e3423cf..ef92801548a13 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1414,6 +1414,22 @@ object functions { */ def md5(columnName: String): Column = md5(Column(columnName)) + /** + * Calculates the SHA-1 digest and returns the value as a 40 character hex string. + * + * @group misc_funcs + * @since 1.5.0 + */ + def sha1(e: Column): Column = Sha1(e.expr) + + /** + * Calculates the SHA-1 digest and returns the value as a 40 character hex string. + * + * @group misc_funcs + * @since 1.5.0 + */ + def sha1(columnName: String): Column = sha1(Column(columnName)) + /** * Calculates the SHA-2 family of hash functions and returns the value as a hex string. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 8baed57a7f129..abfd47c811ed9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -144,6 +144,18 @@ class DataFrameFunctionsSuite extends QueryTest { Row("902fbdd2b1df0c4f70b4a5d23525e932", "6ac1e56bc78f031059be7be854522c4c")) } + test("misc sha1 function") { + val df = Seq(("ABC", "ABC".getBytes)).toDF("a", "b") + checkAnswer( + df.select(sha1($"a"), sha1("b")), + Row("3c01bdbb26f358bab27f267924aa2c9a03fcfdb8", "3c01bdbb26f358bab27f267924aa2c9a03fcfdb8")) + + val dfEmpty = Seq(("", "".getBytes)).toDF("a", "b") + checkAnswer( + dfEmpty.selectExpr("sha1(a)", "sha1(b)"), + Row("da39a3ee5e6b4b0d3255bfef95601890afd80709", "da39a3ee5e6b4b0d3255bfef95601890afd80709")) + } + test("misc sha2 function") { val df = Seq(("ABC", Array[Byte](1, 2, 3, 4, 5, 6))).toDF("a", "b") checkAnswer( From 492dca3a73e70705b5d5639e8fe4640b80e78d31 Mon Sep 17 00:00:00 2001 From: Vladimir Vladimirov Date: Mon, 29 Jun 2015 12:03:41 -0700 Subject: [PATCH 048/122] [SPARK-8528] Expose SparkContext.applicationId in PySpark Use case - we want to log applicationId (YARN in hour case) to request help with troubleshooting from the DevOps Author: Vladimir Vladimirov Closes #6936 from smartkiwi/master and squashes the following commits: 870338b [Vladimir Vladimirov] this would make doctest to run in python3 0eae619 [Vladimir Vladimirov] Scala doesn't use u'...' 
for unicode literals 14d77a8 [Vladimir Vladimirov] stop using ELLIPSIS b4ebfc5 [Vladimir Vladimirov] addressed PR feedback - updated docstring 223a32f [Vladimir Vladimirov] fixed test - applicationId is property that returns the string 3221f5a [Vladimir Vladimirov] [SPARK-8528] added documentation for Scala 2cff090 [Vladimir Vladimirov] [SPARK-8528] add applicationId property for SparkContext object in pyspark --- .../scala/org/apache/spark/SparkContext.scala | 8 ++++++++ python/pyspark/context.py | 15 +++++++++++++++ 2 files changed, 23 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index c7a7436462083..b3c3bf3746e18 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -315,6 +315,14 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli _dagScheduler = ds } + /** + * A unique identifier for the Spark application. + * Its format depends on the scheduler implementation. + * (i.e. + * in case of local spark app something like 'local-1433865536131' + * in case of YARN something like 'application_1433865536131_34483' + * ) + */ def applicationId: String = _applicationId def applicationAttemptId: Option[String] = _applicationAttemptId diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 90b2fffbb9c7c..d7466729b8f36 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -291,6 +291,21 @@ def version(self): """ return self._jsc.version() + @property + @ignore_unicode_prefix + def applicationId(self): + """ + A unique identifier for the Spark application. + Its format depends on the scheduler implementation. + (i.e. + in case of local spark app something like 'local-1433865536131' + in case of YARN something like 'application_1433865536131_34483' + ) + >>> sc.applicationId # doctest: +ELLIPSIS + u'local-...' + """ + return self._jsc.sc().applicationId() + @property def startTime(self): """Return the epoch time when the Spark Context was started.""" From 94e040d05996111b2b448bcdee1cda184c6d039b Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Mon, 29 Jun 2015 12:16:12 -0700 Subject: [PATCH 049/122] [SQL][DOCS] Remove wrong example from DataFrame.scala In DataFrame.scala, there are examples like as follows. ``` * // The following are equivalent: * peopleDf.filter($"age" > 15) * peopleDf.where($"age" > 15) * peopleDf($"age" > 15) ``` But, I think the last example doesn't work. 
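For reference, the two forms that remain in the docs behave the same way in PySpark as well, where `where` is simply an alias for `filter`. A minimal sketch (not part of the patch), assuming the PySpark shell's `sqlContext` and an illustrative `peopleDf`:
```python
# Illustration only: the two documented forms are interchangeable;
# `where` is an alias for `filter` in PySpark.
peopleDf = sqlContext.createDataFrame([("Alice", 12), ("Bob", 20)], ["name", "age"])

filtered1 = peopleDf.filter(peopleDf.age > 15)
filtered2 = peopleDf.where(peopleDf.age > 15)

# Both return the same rows, e.g. [Row(name=u'Bob', age=20)].
assert filtered1.collect() == filtered2.collect()
```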
Author: Kousuke Saruta Closes #6977 from sarutak/fix-dataframe-example and squashes the following commits: 46efbd7 [Kousuke Saruta] Removed wrong example --- sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala | 2 -- 1 file changed, 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index d75d88307562e..986e59133919f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -682,7 +682,6 @@ class DataFrame private[sql]( * // The following are equivalent: * peopleDf.filter($"age" > 15) * peopleDf.where($"age" > 15) - * peopleDf($"age" > 15) * }}} * @group dfops * @since 1.3.0 @@ -707,7 +706,6 @@ class DataFrame private[sql]( * // The following are equivalent: * peopleDf.filter($"age" > 15) * peopleDf.where($"age" > 15) - * peopleDf($"age" > 15) * }}} * @group dfops * @since 1.3.0 From 637b4eedad84dcff1769454137a64ac70c7f2397 Mon Sep 17 00:00:00 2001 From: "zhichao.li" Date: Mon, 29 Jun 2015 12:25:16 -0700 Subject: [PATCH 050/122] [SPARK-8214] [SQL] Add function hex cc chenghao-intel adrian-wang Author: zhichao.li Closes #6976 from zhichao-li/hex and squashes the following commits: e218d1b [zhichao.li] turn off scalastyle for non-ascii de3f5ea [zhichao.li] non-ascii char cf9c936 [zhichao.li] give separated buffer for each hex method 967ec90 [zhichao.li] Make 'value' as a feild of Hex 3b2fa13 [zhichao.li] tiny fix a647641 [zhichao.li] remove duplicate null check 7cab020 [zhichao.li] tiny refactoring 35ecfe5 [zhichao.li] add function hex --- .../catalyst/analysis/FunctionRegistry.scala | 1 + .../spark/sql/catalyst/expressions/math.scala | 86 ++++++++++++++++++- .../expressions/MathFunctionsSuite.scala | 14 ++- .../org/apache/spark/sql/functions.scala | 16 ++++ .../spark/sql/MathExpressionsSuite.scala | 13 +++ 5 files changed, 125 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index b24064d061533..b17457d3094c2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -113,6 +113,7 @@ object FunctionRegistry { expression[Expm1]("expm1"), expression[Floor]("floor"), expression[Hypot]("hypot"), + expression[Hex]("hex"), expression[Logarithm]("log"), expression[Log]("ln"), expression[Log10]("log10"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index 5694afc61be05..4b57ddd9c5768 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -18,9 +18,11 @@ package org.apache.spark.sql.catalyst.expressions import java.lang.{Long => JLong} +import java.util.Arrays +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.types.{DataType, DoubleType, LongType, StringType} +import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String /** @@ -273,9 +275,6 @@ case class Atan2(left: Expression, right: Expression) } } -case class 
Hypot(left: Expression, right: Expression) - extends BinaryMathExpression(math.hypot, "HYPOT") - case class Pow(left: Expression, right: Expression) extends BinaryMathExpression(math.pow, "POWER") { override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { @@ -287,6 +286,85 @@ case class Pow(left: Expression, right: Expression) } } +/** + * If the argument is an INT or binary, hex returns the number as a STRING in hexadecimal format. + * Otherwise if the number is a STRING, + * it converts each character into its hexadecimal representation and returns the resulting STRING. + * Negative numbers would be treated as two's complement. + */ +case class Hex(child: Expression) + extends UnaryExpression with Serializable { + + override def dataType: DataType = StringType + + override def checkInputDataTypes(): TypeCheckResult = { + if (child.dataType.isInstanceOf[StringType] + || child.dataType.isInstanceOf[IntegerType] + || child.dataType.isInstanceOf[LongType] + || child.dataType.isInstanceOf[BinaryType] + || child.dataType == NullType) { + TypeCheckResult.TypeCheckSuccess + } else { + TypeCheckResult.TypeCheckFailure(s"hex doesn't accepts ${child.dataType} type") + } + } + + override def eval(input: InternalRow): Any = { + val num = child.eval(input) + if (num == null) { + null + } else { + child.dataType match { + case LongType => hex(num.asInstanceOf[Long]) + case IntegerType => hex(num.asInstanceOf[Integer].toLong) + case BinaryType => hex(num.asInstanceOf[Array[Byte]]) + case StringType => hex(num.asInstanceOf[UTF8String]) + } + } + } + + /** + * Converts every character in s to two hex digits. + */ + private def hex(str: UTF8String): UTF8String = { + hex(str.getBytes) + } + + private def hex(bytes: Array[Byte]): UTF8String = { + doHex(bytes, bytes.length) + } + + private def doHex(bytes: Array[Byte], length: Int): UTF8String = { + val value = new Array[Byte](length * 2) + var i = 0 + while(i < length) { + value(i * 2) = Character.toUpperCase(Character.forDigit( + (bytes(i) & 0xF0) >>> 4, 16)).toByte + value(i * 2 + 1) = Character.toUpperCase(Character.forDigit( + bytes(i) & 0x0F, 16)).toByte + i += 1 + } + UTF8String.fromBytes(value) + } + + private def hex(num: Long): UTF8String = { + // Extract the hex digits of num into value[] from right to left + val value = new Array[Byte](16) + var numBuf = num + var len = 0 + do { + len += 1 + value(value.length - len) = Character.toUpperCase(Character + .forDigit((numBuf & 0xF).toInt, 16)).toByte + numBuf >>>= 4 + } while (numBuf != 0) + UTF8String.fromBytes(Arrays.copyOfRange(value, value.length - len, value.length)) + } +} + +case class Hypot(left: Expression, right: Expression) + extends BinaryMathExpression(math.hypot, "HYPOT") + case class Logarithm(left: Expression, right: Expression) extends BinaryMathExpression((c1, c2) => math.log(c2) / math.log(c1), "LOG") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala index 0d1d5ebdff2d5..b932d4ab850c7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.dsl.expressions._ 
import org.apache.spark.sql.types.{DataType, DoubleType, LongType} @@ -226,6 +225,19 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { testBinary(Pow, math.pow, Seq((-1.0, 0.9), (-2.2, 1.7), (-2.2, -1.7)), expectNull = true) } + test("hex") { + checkEvaluation(Hex(Literal(28)), "1C") + checkEvaluation(Hex(Literal(-28)), "FFFFFFFFFFFFFFE4") + checkEvaluation(Hex(Literal(100800200404L)), "177828FED4") + checkEvaluation(Hex(Literal(-100800200404L)), "FFFFFFE887D7012C") + checkEvaluation(Hex(Literal("helloHex")), "68656C6C6F486578") + checkEvaluation(Hex(Literal("helloHex".getBytes())), "68656C6C6F486578") + // scalastyle:off + // Turn off scala style for non-ascii chars + checkEvaluation(Hex(Literal("三重的")), "E4B889E9878DE79A84") + // scalastyle:on + } + test("hypot") { testBinary(Hypot, math.hypot) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index ef92801548a13..5422e066afcb1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1046,6 +1046,22 @@ object functions { */ def floor(columnName: String): Column = floor(Column(columnName)) + /** + * Computes hex value of the given column + * + * @group math_funcs + * @since 1.5.0 + */ + def hex(column: Column): Column = Hex(column.expr) + + /** + * Computes hex value of the given input + * + * @group math_funcs + * @since 1.5.0 + */ + def hex(colName: String): Column = hex(Column(colName)) + /** * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala index 2768d7dfc8030..d6331aa4ff09e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala @@ -212,6 +212,19 @@ class MathExpressionsSuite extends QueryTest { ) } + test("hex") { + val data = Seq((28, -28, 100800200404L, "hello")).toDF("a", "b", "c", "d") + checkAnswer(data.select(hex('a)), Seq(Row("1C"))) + checkAnswer(data.select(hex('b)), Seq(Row("FFFFFFFFFFFFFFE4"))) + checkAnswer(data.select(hex('c)), Seq(Row("177828FED4"))) + checkAnswer(data.select(hex('d)), Seq(Row("68656C6C6F"))) + checkAnswer(data.selectExpr("hex(a)"), Seq(Row("1C"))) + checkAnswer(data.selectExpr("hex(b)"), Seq(Row("FFFFFFFFFFFFFFE4"))) + checkAnswer(data.selectExpr("hex(c)"), Seq(Row("177828FED4"))) + checkAnswer(data.selectExpr("hex(d)"), Seq(Row("68656C6C6F"))) + checkAnswer(data.selectExpr("hex(cast(d as binary))"), Seq(Row("68656C6C6F"))) + } + test("hypot") { testTwoToOneMathFunction(hypot, hypot, math.hypot) } From c6ba2ea341ad23de265d870669b25e6a41f461e5 Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Mon, 29 Jun 2015 12:46:33 -0700 Subject: [PATCH 051/122] [SPARK-7862] [SQL] Disable the error message redirect to stderr This is a follow up of #6404, the ScriptTransformation prints the error msg into stderr directly, probably be a disaster for application log. 
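The idea is to keep draining the script's stderr on its own thread into a small bounded buffer, and to log only the captured tail when the process fails. A rough Python sketch of that pattern (the actual change below is the Scala CircularBuffer/RedirectThread; the class and helper names here are illustrative only, not Spark APIs):
```python
# Sketch of the pattern: drain stderr continuously so the child never blocks
# on a full pipe, keep only the most recent bytes, print them on failure.
import subprocess
import threading

class LastBytesBuffer(object):
    """Keeps only the most recent `size` bytes written to it."""
    def __init__(self, size=2048):
        self.size = size
        self.data = bytearray()

    def write(self, chunk):
        self.data.extend(chunk)
        if len(self.data) > self.size:
            # Drop the oldest bytes so at most `size` bytes are retained.
            del self.data[:len(self.data) - self.size]

    def __str__(self):
        return self.data.decode("utf-8", "replace")

def drain(stream, buf):
    # Consume the error stream until EOF, otherwise the child could block
    # once the pipe's buffer fills up.
    for chunk in iter(lambda: stream.read(256), b""):
        buf.write(chunk)

proc = subprocess.Popen(["/bin/bash", "-c", "echo boom >&2; exit 1"],
                        stdout=subprocess.PIPE, stderr=subprocess.PIPE)
stderr_tail = LastBytesBuffer(2048)
t = threading.Thread(target=drain, args=(proc.stderr, stderr_tail))
t.start()
proc.wait()
t.join()
if proc.returncode != 0:
    print("script failed, last stderr output:\n%s" % stderr_tail)
```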
Author: Cheng Hao Closes #6882 from chenghao-intel/verbose and squashes the following commits: bfedd77 [Cheng Hao] revert the write 76ff46b [Cheng Hao] update the CircularBuffer 692b19e [Cheng Hao] check the process exitValue for ScriptTransform 47e0970 [Cheng Hao] Use the RedirectThread instead 1de771d [Cheng Hao] naming the threads in ScriptTransformation 8536e81 [Cheng Hao] disable the error message redirection for stderr --- .../scala/org/apache/spark/util/Utils.scala | 33 ++++++++++++ .../org/apache/spark/util/UtilsSuite.scala | 8 +++ .../spark/sql/hive/client/ClientWrapper.scala | 29 ++--------- .../hive/execution/ScriptTransformation.scala | 51 ++++++++++++------- .../sql/hive/execution/SQLQuerySuite.scala | 2 +- 5 files changed, 77 insertions(+), 46 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 19157af5b6f4d..a7fc749a2b0c6 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -2333,3 +2333,36 @@ private[spark] class RedirectThread( } } } + +/** + * An [[OutputStream]] that will store the last 10 kilobytes (by default) written to it + * in a circular buffer. The current contents of the buffer can be accessed using + * the toString method. + */ +private[spark] class CircularBuffer(sizeInBytes: Int = 10240) extends java.io.OutputStream { + var pos: Int = 0 + var buffer = new Array[Int](sizeInBytes) + + def write(i: Int): Unit = { + buffer(pos) = i + pos = (pos + 1) % buffer.length + } + + override def toString: String = { + val (end, start) = buffer.splitAt(pos) + val input = new java.io.InputStream { + val iterator = (start ++ end).iterator + + def read(): Int = if (iterator.hasNext) iterator.next() else -1 + } + val reader = new BufferedReader(new InputStreamReader(input)) + val stringBuilder = new StringBuilder + var line = reader.readLine() + while (line != null) { + stringBuilder.append(line) + stringBuilder.append("\n") + line = reader.readLine() + } + stringBuilder.toString() + } +} diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index a61ea3918f46a..baa4c661cc21e 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -673,4 +673,12 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { assert(!Utils.isInDirectory(nullFile, parentDir)) assert(!Utils.isInDirectory(nullFile, childFile3)) } + + test("circular buffer") { + val buffer = new CircularBuffer(25) + val stream = new java.io.PrintStream(buffer, true, "UTF-8") + + stream.println("test circular test circular test circular test circular test circular") + assert(buffer.toString === "t circular test circular\n") + } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala index 4c708cec572ae..cbd2bf6b5eede 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala @@ -22,6 +22,8 @@ import java.net.URI import java.util.{ArrayList => JArrayList, Map => JMap, List => JList, Set => JSet} import javax.annotation.concurrent.GuardedBy +import org.apache.spark.util.CircularBuffer + import scala.collection.JavaConversions._ import 
scala.language.reflectiveCalls @@ -66,32 +68,7 @@ private[hive] class ClientWrapper( with Logging { // Circular buffer to hold what hive prints to STDOUT and ERR. Only printed when failures occur. - private val outputBuffer = new java.io.OutputStream { - var pos: Int = 0 - var buffer = new Array[Int](10240) - def write(i: Int): Unit = { - buffer(pos) = i - pos = (pos + 1) % buffer.size - } - - override def toString: String = { - val (end, start) = buffer.splitAt(pos) - val input = new java.io.InputStream { - val iterator = (start ++ end).iterator - - def read(): Int = if (iterator.hasNext) iterator.next() else -1 - } - val reader = new BufferedReader(new InputStreamReader(input)) - val stringBuilder = new StringBuilder - var line = reader.readLine() - while(line != null) { - stringBuilder.append(line) - stringBuilder.append("\n") - line = reader.readLine() - } - stringBuilder.toString() - } - } + private val outputBuffer = new CircularBuffer() private val shim = version match { case hive.v12 => new Shim_v0_12() diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala index 611888055d6cf..b967e191c5855 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala @@ -35,7 +35,7 @@ import org.apache.spark.sql.execution._ import org.apache.spark.sql.hive.HiveShim._ import org.apache.spark.sql.hive.{HiveContext, HiveInspectors} import org.apache.spark.sql.types.DataType -import org.apache.spark.util.Utils +import org.apache.spark.util.{CircularBuffer, RedirectThread, Utils} /** * Transforms the input by forking and running the specified script. @@ -59,15 +59,13 @@ case class ScriptTransformation( child.execute().mapPartitions { iter => val cmd = List("/bin/bash", "-c", script) val builder = new ProcessBuilder(cmd) - // redirectError(Redirect.INHERIT) would consume the error output from buffer and - // then print it to stderr (inherit the target from the current Scala process). - // If without this there would be 2 issues: + // We need to start threads connected to the process pipeline: // 1) The error msg generated by the script process would be hidden. // 2) If the error msg is too big to chock up the buffer, the input logic would be hung - builder.redirectError(Redirect.INHERIT) val proc = builder.start() val inputStream = proc.getInputStream val outputStream = proc.getOutputStream + val errorStream = proc.getErrorStream val reader = new BufferedReader(new InputStreamReader(inputStream)) val (outputSerde, outputSoi) = ioschema.initOutputSerDe(output) @@ -152,29 +150,43 @@ case class ScriptTransformation( val dataOutputStream = new DataOutputStream(outputStream) val outputProjection = new InterpretedProjection(input, child.output) + // TODO make the 2048 configurable? + val stderrBuffer = new CircularBuffer(2048) + // Consume the error stream from the pipeline, otherwise it will be blocked if + // the pipeline is full. + new RedirectThread(errorStream, // input stream from the pipeline + stderrBuffer, // output to a circular buffer + "Thread-ScriptTransformation-STDERR-Consumer").start() + // Put the write(output to the pipeline) into a single thread // and keep the collector as remain in the main thread. // otherwise it will causes deadlock if the data size greater than // the pipeline / buffer capacity. 
new Thread(new Runnable() { override def run(): Unit = { - iter - .map(outputProjection) - .foreach { row => - if (inputSerde == null) { - val data = row.mkString("", ioschema.inputRowFormatMap("TOK_TABLEROWFORMATFIELD"), - ioschema.inputRowFormatMap("TOK_TABLEROWFORMATLINES")).getBytes("utf-8") - - outputStream.write(data) - } else { - val writable = inputSerde.serialize( - row.asInstanceOf[GenericInternalRow].values, inputSoi) - prepareWritable(writable).write(dataOutputStream) + Utils.tryWithSafeFinally { + iter + .map(outputProjection) + .foreach { row => + if (inputSerde == null) { + val data = row.mkString("", ioschema.inputRowFormatMap("TOK_TABLEROWFORMATFIELD"), + ioschema.inputRowFormatMap("TOK_TABLEROWFORMATLINES")).getBytes("utf-8") + + outputStream.write(data) + } else { + val writable = inputSerde.serialize( + row.asInstanceOf[GenericInternalRow].values, inputSoi) + prepareWritable(writable).write(dataOutputStream) + } + } + outputStream.close() + } { + if (proc.waitFor() != 0) { + logError(stderrBuffer.toString) // log the stderr circular buffer } } - outputStream.close() } - }).start() + }, "Thread-ScriptTransformation-Feed").start() iterator } @@ -278,3 +290,4 @@ case class HiveScriptIOSchema ( } } } + diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index f0aad8dbbe64d..9f7e58f890241 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -653,7 +653,7 @@ class SQLQuerySuite extends QueryTest { .queryExecution.toRdd.count()) } - ignore("test script transform for stderr") { + test("test script transform for stderr") { val data = (1 to 100000).map { i => (i, i, i) } data.toDF("d1", "d2", "d3").registerTempTable("script_trans") assert(0 === From be7ef067620408859144e0244b0f1b8eb56faa86 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Mon, 29 Jun 2015 13:15:04 -0700 Subject: [PATCH 052/122] [SPARK-8681] fixed wrong ordering of columns in crosstab I specifically randomized the test. What crosstab does is equivalent to a countByKey, therefore if this test fails again for any reason, we will know that we hit a corner case or something. cc rxin marmbrus Author: Burak Yavuz Closes #7060 from brkyvz/crosstab-fixes and squashes the following commits: 0a65234 [Burak Yavuz] addressed comments v1 d96da7e [Burak Yavuz] fixed wrong ordering of columns in crosstab --- .../sql/execution/stat/StatFunctions.scala | 8 ++++-- .../apache/spark/sql/DataFrameStatSuite.scala | 28 ++++++++++--------- 2 files changed, 20 insertions(+), 16 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala index 042e2c9cbb22e..b624ef7e8fa1a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala @@ -111,7 +111,7 @@ private[sql] object StatFunctions extends Logging { "the pairs. 
Please try reducing the amount of distinct items in your columns.") } // get the distinct values of column 2, so that we can make them the column names - val distinctCol2 = counts.map(_.get(1)).distinct.zipWithIndex.toMap + val distinctCol2: Map[Any, Int] = counts.map(_.get(1)).distinct.zipWithIndex.toMap val columnSize = distinctCol2.size require(columnSize < 1e4, s"The number of distinct values for $col2, can't " + s"exceed 1e4. Currently $columnSize") @@ -120,14 +120,16 @@ private[sql] object StatFunctions extends Logging { rows.foreach { (row: Row) => // row.get(0) is column 1 // row.get(1) is column 2 - // row.get(3) is the frequency + // row.get(2) is the frequency countsRow.setLong(distinctCol2.get(row.get(1)).get + 1, row.getLong(2)) } // the value of col1 is the first value, the rest are the counts countsRow.update(0, UTF8String.fromString(col1Item.toString)) countsRow }.toSeq - val headerNames = distinctCol2.map(r => StructField(r._1.toString, LongType)).toSeq + // In the map, the column names (._1) are not ordered by the index (._2). This was the bug in + // SPARK-8681. We need to explicitly sort by the column index and assign the column names. + val headerNames = distinctCol2.toSeq.sortBy(_._2).map(r => StructField(r._1.toString, LongType)) val schema = StructType(StructField(tableName, StringType) +: headerNames) new DataFrame(df.sqlContext, LocalRelation(schema.toAttributes, table)).na.fill(0.0) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index 0d3ff899dad72..64ec1a70c47e6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql +import java.util.Random + import org.scalatest.Matchers._ import org.apache.spark.SparkFunSuite @@ -65,22 +67,22 @@ class DataFrameStatSuite extends SparkFunSuite { } test("crosstab") { - val df = Seq((0, 0), (2, 1), (1, 0), (2, 0), (0, 0), (2, 0)).toDF("a", "b") + val rng = new Random() + val data = Seq.tabulate(25)(i => (rng.nextInt(5), rng.nextInt(10))) + val df = data.toDF("a", "b") val crosstab = df.stat.crosstab("a", "b") val columnNames = crosstab.schema.fieldNames assert(columnNames(0) === "a_b") - assert(columnNames(1) === "0") - assert(columnNames(2) === "1") - val rows: Array[Row] = crosstab.collect().sortBy(_.getString(0)) - assert(rows(0).get(0).toString === "0") - assert(rows(0).getLong(1) === 2L) - assert(rows(0).get(2) === 0L) - assert(rows(1).get(0).toString === "1") - assert(rows(1).getLong(1) === 1L) - assert(rows(1).get(2) === 0L) - assert(rows(2).get(0).toString === "2") - assert(rows(2).getLong(1) === 2L) - assert(rows(2).getLong(2) === 1L) + // reduce by key + val expected = data.map(t => (t, 1)).groupBy(_._1).mapValues(_.length) + val rows = crosstab.collect() + rows.foreach { row => + val i = row.getString(0).toInt + for (col <- 1 to 9) { + val j = columnNames(col).toInt + assert(row.getLong(col) === expected.getOrElse((i, j), 0).toLong) + } + } } test("Frequent Items") { From afae9766f28d2e58297405c39862d20a04267b62 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 29 Jun 2015 13:20:55 -0700 Subject: [PATCH 053/122] [SPARK-8070] [SQL] [PYSPARK] avoid spark jobs in createDataFrame Avoid the unnecessary jobs when infer schema from list. 
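With this change, passing a local list means the schema is inferred on the driver from the list itself rather than by first launching a job (e.g. rdd.first()/take()) to sample the data. A small usage sketch, assuming the PySpark shell's `sqlContext`:
```python
# Usage sketch: schema inference from a local list happens driver-side,
# so no Spark job is run just to sample the data.
from pyspark.sql import Row

rows = [Row(name="Alice", age=5), Row(name="Bob", age=12)]
df = sqlContext.createDataFrame(rows)
df.printSchema()

# Column names may also be supplied as a plain list while types are inferred:
df2 = sqlContext.createDataFrame([("Alice", 5), ("Bob", 12)], ["name", "age"])
df2.show()
```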
cc yhuai mengxr Author: Davies Liu Closes #6606 from davies/improve_create and squashes the following commits: a5928bf [Davies Liu] Update MimaExcludes.scala 62da911 [Davies Liu] fix mima bab4d7d [Davies Liu] Merge branch 'improve_create' of github.com:davies/spark into improve_create eee44a8 [Davies Liu] Merge branch 'master' of github.com:apache/spark into improve_create 8d9292d [Davies Liu] Update context.py eb24531 [Davies Liu] Update context.py c969997 [Davies Liu] bug fix d5a8ab0 [Davies Liu] fix tests 8c3f10d [Davies Liu] Merge branch 'master' of github.com:apache/spark into improve_create 6ea5925 [Davies Liu] address comments 6ceaeff [Davies Liu] avoid spark jobs in createDataFrame --- python/pyspark/sql/context.py | 64 +++++++++++++++++++++++++---------- python/pyspark/sql/types.py | 48 +++++++++++++++----------- 2 files changed, 75 insertions(+), 37 deletions(-) diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index dc239226e6d3c..4dda3b430cfbf 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -203,7 +203,37 @@ def registerFunction(self, name, f, returnType=StringType()): self._sc._javaAccumulator, returnType.json()) + def _inferSchemaFromList(self, data): + """ + Infer schema from list of Row or tuple. + + :param data: list of Row or tuple + :return: StructType + """ + if not data: + raise ValueError("can not infer schema from empty dataset") + first = data[0] + if type(first) is dict: + warnings.warn("inferring schema from dict is deprecated," + "please use pyspark.sql.Row instead") + schema = _infer_schema(first) + if _has_nulltype(schema): + for r in data: + schema = _merge_type(schema, _infer_schema(r)) + if not _has_nulltype(schema): + break + else: + raise ValueError("Some of types cannot be determined after inferring") + return schema + def _inferSchema(self, rdd, samplingRatio=None): + """ + Infer schema from an RDD of Row or tuple. + + :param rdd: an RDD of Row or tuple + :param samplingRatio: sampling ratio, or no sampling (default) + :return: StructType + """ first = rdd.first() if not first: raise ValueError("The first row in RDD is empty, " @@ -322,6 +352,8 @@ def createDataFrame(self, data, schema=None, samplingRatio=None): data = [r.tolist() for r in data.to_records(index=False)] if not isinstance(data, RDD): + if not isinstance(data, list): + data = list(data) try: # data could be list, tuple, generator ... 
rdd = self._sc.parallelize(data) @@ -330,28 +362,26 @@ def createDataFrame(self, data, schema=None, samplingRatio=None): else: rdd = data - if schema is None: - schema = self._inferSchema(rdd, samplingRatio) + if schema is None or isinstance(schema, (list, tuple)): + if isinstance(data, RDD): + struct = self._inferSchema(rdd, samplingRatio) + else: + struct = self._inferSchemaFromList(data) + if isinstance(schema, (list, tuple)): + for i, name in enumerate(schema): + struct.fields[i].name = name + schema = struct converter = _create_converter(schema) rdd = rdd.map(converter) - if isinstance(schema, (list, tuple)): - first = rdd.first() - if not isinstance(first, (list, tuple)): - raise TypeError("each row in `rdd` should be list or tuple, " - "but got %r" % type(first)) - row_cls = Row(*schema) - schema = self._inferSchema(rdd.map(lambda r: row_cls(*r)), samplingRatio) - - # take the first few rows to verify schema - rows = rdd.take(10) - # Row() cannot been deserialized by Pyrolite - if rows and isinstance(rows[0], tuple) and rows[0].__class__.__name__ == 'Row': - rdd = rdd.map(tuple) + elif isinstance(schema, StructType): + # take the first few rows to verify schema rows = rdd.take(10) + for row in rows: + _verify_type(row, schema) - for row in rows: - _verify_type(row, schema) + else: + raise TypeError("schema should be StructType or list or None") # convert python objects to sql data converter = _python_to_sql_converter(schema) diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 23d9adb0daea1..932686e5e4b01 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -635,7 +635,7 @@ def _need_python_to_sql_conversion(dataType): >>> schema0 = StructType([StructField("indices", ArrayType(IntegerType(), False), False), ... 
StructField("values", ArrayType(DoubleType(), False), False)]) >>> _need_python_to_sql_conversion(schema0) - False + True >>> _need_python_to_sql_conversion(ExamplePointUDT()) True >>> schema1 = ArrayType(ExamplePointUDT(), False) @@ -647,7 +647,8 @@ def _need_python_to_sql_conversion(dataType): True """ if isinstance(dataType, StructType): - return any([_need_python_to_sql_conversion(f.dataType) for f in dataType.fields]) + # convert namedtuple or Row into tuple + return True elif isinstance(dataType, ArrayType): return _need_python_to_sql_conversion(dataType.elementType) elif isinstance(dataType, MapType): @@ -688,21 +689,25 @@ def _python_to_sql_converter(dataType): if isinstance(dataType, StructType): names, types = zip(*[(f.name, f.dataType) for f in dataType.fields]) - converters = [_python_to_sql_converter(t) for t in types] - - def converter(obj): - if isinstance(obj, dict): - return tuple(c(obj.get(n)) for n, c in zip(names, converters)) - elif isinstance(obj, tuple): - if hasattr(obj, "__fields__") or hasattr(obj, "_fields"): - return tuple(c(v) for c, v in zip(converters, obj)) - elif all(isinstance(x, tuple) and len(x) == 2 for x in obj): # k-v pairs - d = dict(obj) - return tuple(c(d.get(n)) for n, c in zip(names, converters)) + if any(_need_python_to_sql_conversion(t) for t in types): + converters = [_python_to_sql_converter(t) for t in types] + + def converter(obj): + if isinstance(obj, dict): + return tuple(c(obj.get(n)) for n, c in zip(names, converters)) + elif isinstance(obj, tuple): + if hasattr(obj, "__fields__") or hasattr(obj, "_fields"): + return tuple(c(v) for c, v in zip(converters, obj)) + else: + return tuple(c(v) for c, v in zip(converters, obj)) + elif obj is not None: + raise ValueError("Unexpected tuple %r with type %r" % (obj, dataType)) + else: + def converter(obj): + if isinstance(obj, dict): + return tuple(obj.get(n) for n in names) else: - return tuple(c(v) for c, v in zip(converters, obj)) - elif obj is not None: - raise ValueError("Unexpected tuple %r with type %r" % (obj, dataType)) + return tuple(obj) return converter elif isinstance(dataType, ArrayType): element_converter = _python_to_sql_converter(dataType.elementType) @@ -1027,10 +1032,13 @@ def _verify_type(obj, dataType): _type = type(dataType) assert _type in _acceptable_types, "unknown datatype: %s" % dataType - # subclass of them can not be deserialized in JVM - if type(obj) not in _acceptable_types[_type]: - raise TypeError("%s can not accept object in type %s" - % (dataType, type(obj))) + if _type is StructType: + if not isinstance(obj, (tuple, list)): + raise TypeError("StructType can not accept object in type %s" % type(obj)) + else: + # subclass of them can not be deserialized in JVM + if type(obj) not in _acceptable_types[_type]: + raise TypeError("%s can not accept object in type %s" % (dataType, type(obj))) if isinstance(dataType, ArrayType): for i in obj: From 27ef85451cd237caa7016baa69957a35ab365aa8 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 29 Jun 2015 14:07:55 -0700 Subject: [PATCH 054/122] [SPARK-8709] Exclude hadoop-client's mockito-all dependency This patch excludes `hadoop-client`'s dependency on `mockito-all`. As of #7061, Spark depends on `mockito-core` instead of `mockito-all`, so the dependency from Hadoop was leading to test compilation failures for some of the Hadoop 2 SBT builds. Author: Josh Rosen Closes #7090 from JoshRosen/SPARK-8709 and squashes the following commits: e190122 [Josh Rosen] [SPARK-8709] Exclude hadoop-client's mockito-all dependency. 
--- LICENSE | 2 +- core/pom.xml | 10 ---------- launcher/pom.xml | 6 ------ pom.xml | 8 ++++++++ 4 files changed, 9 insertions(+), 17 deletions(-) diff --git a/LICENSE b/LICENSE index 8672be55eca3e..f9e412cade345 100644 --- a/LICENSE +++ b/LICENSE @@ -948,6 +948,6 @@ The following components are provided under the MIT License. See project link fo (MIT License) SLF4J LOG4J-12 Binding (org.slf4j:slf4j-log4j12:1.7.5 - http://www.slf4j.org) (MIT License) pyrolite (org.spark-project:pyrolite:2.0.1 - http://pythonhosted.org/Pyro4/) (MIT License) scopt (com.github.scopt:scopt_2.10:3.2.0 - https://github.com/scopt/scopt) - (The MIT License) Mockito (org.mockito:mockito-core:1.8.5 - http://www.mockito.org) + (The MIT License) Mockito (org.mockito:mockito-core:1.9.5 - http://www.mockito.org) (MIT License) jquery (https://jquery.org/license/) (MIT License) AnchorJS (https://github.com/bryanbraun/anchorjs) diff --git a/core/pom.xml b/core/pom.xml index 565437c4861a4..aee0d92620606 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -69,16 +69,6 @@ org.apache.hadoop hadoop-client - - - javax.servlet - servlet-api - - - org.codehaus.jackson - jackson-mapper-asl - - org.apache.spark diff --git a/launcher/pom.xml b/launcher/pom.xml index a853e67f5cf78..2fd768d8119c4 100644 --- a/launcher/pom.xml +++ b/launcher/pom.xml @@ -68,12 +68,6 @@ org.apache.hadoop hadoop-client test - - - org.codehaus.jackson - jackson-mapper-asl - - diff --git a/pom.xml b/pom.xml index 4c18bd5e42c87..94dd512cfb618 100644 --- a/pom.xml +++ b/pom.xml @@ -747,6 +747,10 @@ asm asm + + org.codehaus.jackson + jackson-mapper-asl + org.ow2.asm asm @@ -759,6 +763,10 @@ commons-logging commons-logging + + org.mockito + mockito-all + org.mortbay.jetty servlet-api-2.5 From f6fc254ec4ce5f103d45da6d007b4066ce751236 Mon Sep 17 00:00:00 2001 From: Ilya Ganelin Date: Mon, 29 Jun 2015 14:15:15 -0700 Subject: [PATCH 055/122] [SPARK-8056][SQL] Design an easier way to construct schema for both Scala and Python I've added functionality to create new StructType similar to how we add parameters to a new SparkContext. I've also added tests for this type of creation. Author: Ilya Ganelin Closes #6686 from ilganeli/SPARK-8056B and squashes the following commits: 27c1de1 [Ilya Ganelin] Rename 467d836 [Ilya Ganelin] Removed from_string in favor of _parse_Datatype_json_value 5fef5a4 [Ilya Ganelin] Updates for type parsing 4085489 [Ilya Ganelin] Style errors 3670cf5 [Ilya Ganelin] added string to DataType conversion 8109e00 [Ilya Ganelin] Fixed error in tests 41ab686 [Ilya Ganelin] Fixed style errors e7ba7e0 [Ilya Ganelin] Moved some python tests to tests.py. 
Added cleaner handling of null data type and added test for correctness of input format 15868fa [Ilya Ganelin] Fixed python errors b79b992 [Ilya Ganelin] Merge remote-tracking branch 'upstream/master' into SPARK-8056B a3369fc [Ilya Ganelin] Fixing space errors e240040 [Ilya Ganelin] Style bab7823 [Ilya Ganelin] Constructor error 73d4677 [Ilya Ganelin] Style 4ed00d9 [Ilya Ganelin] Fixed default arg 67df57a [Ilya Ganelin] Removed Foo 04cbf0c [Ilya Ganelin] Added comments for single object 0484d7a [Ilya Ganelin] Restored second method 6aeb740 [Ilya Ganelin] Style 689e54d [Ilya Ganelin] Style f497e9e [Ilya Ganelin] Got rid of old code e3c7a88 [Ilya Ganelin] Fixed doctest failure a62ccde [Ilya Ganelin] Style 966ac06 [Ilya Ganelin] style checks dabb7e6 [Ilya Ganelin] Added Python tests a3f4152 [Ilya Ganelin] added python bindings and better comments e6e536c [Ilya Ganelin] Added extra space 7529a2e [Ilya Ganelin] Fixed formatting d388f86 [Ilya Ganelin] Fixed small bug c4e3bf5 [Ilya Ganelin] Reverted to using parse. Updated parse to support long d7634b6 [Ilya Ganelin] Reverted to fromString to properly support types 22c39d5 [Ilya Ganelin] replaced FromString with DataTypeParser.parse. Replaced empty constructor initializing a null to have it instead create a new array to allow appends to it. faca398 [Ilya Ganelin] [SPARK-8056] Replaced default argument usage. Updated usage and code for DataType.fromString 1acf76e [Ilya Ganelin] Scala style e31c674 [Ilya Ganelin] Fixed bug in test 8dc0795 [Ilya Ganelin] Added tests for creation of StructType object with new methods fdf7e9f [Ilya Ganelin] [SPARK-8056] Created add methods to facilitate building new StructType objects. --- python/pyspark/sql/tests.py | 29 +++++ python/pyspark/sql/types.py | 52 ++++++++- .../spark/sql/types/DataTypeParser.scala | 2 +- .../apache/spark/sql/types/StructType.scala | 104 +++++++++++++++++- .../spark/sql/types/DataTypeSuite.scala | 31 ++++++ 5 files changed, 212 insertions(+), 6 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index ffee43a94baba..34f397d0ffef0 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -516,6 +516,35 @@ def test_between_function(self): self.assertEqual([Row(a=2, b=1, c=3), Row(a=4, b=1, c=4)], df.filter(df.a.between(df.b, df.c)).collect()) + def test_struct_type(self): + from pyspark.sql.types import StructType, StringType, StructField + struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None) + struct2 = StructType([StructField("f1", StringType(), True), + StructField("f2", StringType(), True, None)]) + self.assertEqual(struct1, struct2) + + struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None) + struct2 = StructType([StructField("f1", StringType(), True)]) + self.assertNotEqual(struct1, struct2) + + struct1 = (StructType().add(StructField("f1", StringType(), True)) + .add(StructField("f2", StringType(), True, None))) + struct2 = StructType([StructField("f1", StringType(), True), + StructField("f2", StringType(), True, None)]) + self.assertEqual(struct1, struct2) + + struct1 = (StructType().add(StructField("f1", StringType(), True)) + .add(StructField("f2", StringType(), True, None))) + struct2 = StructType([StructField("f1", StringType(), True)]) + self.assertNotEqual(struct1, struct2) + + # Catch exception raised during improper construction + try: + struct1 = StructType().add("name") + self.assertEqual(1, 0) + except ValueError: + self.assertEqual(1, 1) + def 
test_save_and_load(self): df = self.df tmpPath = tempfile.mkdtemp() diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 932686e5e4b01..ae9344e6106a4 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -355,8 +355,7 @@ class StructType(DataType): This is the data type representing a :class:`Row`. """ - - def __init__(self, fields): + def __init__(self, fields=None): """ >>> struct1 = StructType([StructField("f1", StringType(), True)]) >>> struct2 = StructType([StructField("f1", StringType(), True)]) @@ -368,8 +367,53 @@ def __init__(self, fields): >>> struct1 == struct2 False """ - assert all(isinstance(f, DataType) for f in fields), "fields should be a list of DataType" - self.fields = fields + if not fields: + self.fields = [] + else: + self.fields = fields + assert all(isinstance(f, StructField) for f in fields),\ + "fields should be a list of StructField" + + def add(self, field, data_type=None, nullable=True, metadata=None): + """ + Construct a StructType by adding new elements to it to define the schema. The method accepts + either: + a) A single parameter which is a StructField object. + b) Between 2 and 4 parameters as (name, data_type, nullable (optional), + metadata(optional). The data_type parameter may be either a String or a DataType object + + >>> struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None) + >>> struct2 = StructType([StructField("f1", StringType(), True),\ + StructField("f2", StringType(), True, None)]) + >>> struct1 == struct2 + True + >>> struct1 = StructType().add(StructField("f1", StringType(), True)) + >>> struct2 = StructType([StructField("f1", StringType(), True)]) + >>> struct1 == struct2 + True + >>> struct1 = StructType().add("f1", "string", True) + >>> struct2 = StructType([StructField("f1", StringType(), True)]) + >>> struct1 == struct2 + True + + :param field: Either the name of the field or a StructField object + :param data_type: If present, the DataType of the StructField to create + :param nullable: Whether the field to add should be nullable (default True) + :param metadata: Any additional metadata (default None) + :return: a new updated StructType + """ + if isinstance(field, StructField): + self.fields.append(field) + else: + if isinstance(field, str) and data_type is None: + raise ValueError("Must specify DataType if passing name of struct_field to create.") + + if isinstance(data_type, str): + data_type_f = _parse_datatype_json_value(data_type) + else: + data_type_f = data_type + self.fields.append(StructField(field, data_type_f, nullable, metadata)) + return self def simpleString(self): return 'struct<%s>' % (','.join(f.simpleString() for f in self.fields)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeParser.scala index 04f3379afb38d..6b43224feb1f2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataTypeParser.scala @@ -44,7 +44,7 @@ private[sql] trait DataTypeParser extends StandardTokenParsers { "(?i)tinyint".r ^^^ ByteType | "(?i)smallint".r ^^^ ShortType | "(?i)double".r ^^^ DoubleType | - "(?i)bigint".r ^^^ LongType | + "(?i)(?:bigint|long)".r ^^^ LongType | "(?i)binary".r ^^^ BinaryType | "(?i)boolean".r ^^^ BooleanType | fixedDecimalType | diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index 193c08a4d0df7..2db0a359e9db5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -94,7 +94,7 @@ import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Attribute} case class StructType(fields: Array[StructField]) extends DataType with Seq[StructField] { /** No-arg constructor for kryo. */ - protected def this() = this(null) + def this() = this(Array.empty[StructField]) /** Returns all field names in an array. */ def fieldNames: Array[String] = fields.map(_.name) @@ -103,6 +103,108 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru private lazy val nameToField: Map[String, StructField] = fields.map(f => f.name -> f).toMap private lazy val nameToIndex: Map[String, Int] = fieldNames.zipWithIndex.toMap + /** + * Creates a new [[StructType]] by adding a new field. + * {{{ + * val struct = (new StructType) + * .add(StructField("a", IntegerType, true)) + * .add(StructField("b", LongType, false)) + * .add(StructField("c", StringType, true)) + *}}} + */ + def add(field: StructField): StructType = { + StructType(fields :+ field) + } + + /** + * Creates a new [[StructType]] by adding a new nullable field with no metadata. + * + * val struct = (new StructType) + * .add("a", IntegerType) + * .add("b", LongType) + * .add("c", StringType) + */ + def add(name: String, dataType: DataType): StructType = { + StructType(fields :+ new StructField(name, dataType, nullable = true, Metadata.empty)) + } + + /** + * Creates a new [[StructType]] by adding a new field with no metadata. + * + * val struct = (new StructType) + * .add("a", IntegerType, true) + * .add("b", LongType, false) + * .add("c", StringType, true) + */ + def add(name: String, dataType: DataType, nullable: Boolean): StructType = { + StructType(fields :+ new StructField(name, dataType, nullable, Metadata.empty)) + } + + /** + * Creates a new [[StructType]] by adding a new field and specifying metadata. + * {{{ + * val struct = (new StructType) + * .add("a", IntegerType, true, Metadata.empty) + * .add("b", LongType, false, Metadata.empty) + * .add("c", StringType, true, Metadata.empty) + * }}} + */ + def add( + name: String, + dataType: DataType, + nullable: Boolean, + metadata: Metadata): StructType = { + StructType(fields :+ new StructField(name, dataType, nullable, metadata)) + } + + /** + * Creates a new [[StructType]] by adding a new nullable field with no metadata where the + * dataType is specified as a String. + * + * {{{ + * val struct = (new StructType) + * .add("a", "int") + * .add("b", "long") + * .add("c", "string") + * }}} + */ + def add(name: String, dataType: String): StructType = { + add(name, DataTypeParser.parse(dataType), nullable = true, Metadata.empty) + } + + /** + * Creates a new [[StructType]] by adding a new field with no metadata where the + * dataType is specified as a String. + * + * {{{ + * val struct = (new StructType) + * .add("a", "int", true) + * .add("b", "long", false) + * .add("c", "string", true) + * }}} + */ + def add(name: String, dataType: String, nullable: Boolean): StructType = { + add(name, DataTypeParser.parse(dataType), nullable, Metadata.empty) + } + + /** + * Creates a new [[StructType]] by adding a new field and specifying metadata where the + * dataType is specified as a String. 
+ * {{{ + * val struct = (new StructType) + * .add("a", "int", true, Metadata.empty) + * .add("b", "long", false, Metadata.empty) + * .add("c", "string", true, Metadata.empty) + * }}} + */ + def add( + name: String, + dataType: String, + nullable: Boolean, + metadata: Metadata): StructType = { + add(name, DataTypeParser.parse(dataType), nullable, metadata) + } + /** * Extracts a [[StructField]] of the given name. If the [[StructType]] object does not * have a name matching the given name, `null` will be returned. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala index 077c0ad70ac4f..14e7b4a9561b6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeSuite.scala @@ -33,6 +33,37 @@ class DataTypeSuite extends SparkFunSuite { assert(MapType(StringType, IntegerType, true) === map) } + test("construct with add") { + val struct = (new StructType) + .add("a", IntegerType, true) + .add("b", LongType, false) + .add("c", StringType, true) + + assert(StructField("b", LongType, false) === struct("b")) + } + + test("construct with add from StructField") { + // Test creation from StructField type + val struct = (new StructType) + .add(StructField("a", IntegerType, true)) + .add(StructField("b", LongType, false)) + .add(StructField("c", StringType, true)) + + assert(StructField("b", LongType, false) === struct("b")) + } + + test("construct with String DataType") { + // Test creation with DataType as String + val struct = (new StructType) + .add("a", "int", true) + .add("b", "long", false) + .add("c", "string", true) + + assert(StructField("a", IntegerType, true) === struct("a")) + assert(StructField("b", LongType, false) === struct("b")) + assert(StructField("c", StringType, true) === struct("c")) + } + test("extract fields from a StructType") { val struct = StructType( StructField("a", IntegerType, true) :: From ecd3aacf2805bb231cfb44bab079319cfe73c3f1 Mon Sep 17 00:00:00 2001 From: Ai He Date: Mon, 29 Jun 2015 14:36:26 -0700 Subject: [PATCH 056/122] [SPARK-7810] [PYSPARK] solve python rdd socket connection problem Method "_load_from_socket" in rdd.py cannot load data from jvm socket when ipv6 is used. The current method only works well with ipv4. New modification should work around both two protocols. Author: Ai He Author: AiHe Closes #6338 from AiHe/pyspark-networking-issue and squashes the following commits: d4fc9c4 [Ai He] handle code review 2 e75c5c8 [Ai He] handle code review 5644953 [AiHe] solve python rdd socket connection problem to jvm --- python/pyspark/rdd.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 1b64be23a667e..cb20bc8b54027 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -121,10 +121,22 @@ def _parse_memory(s): def _load_from_socket(port, serializer): - sock = socket.socket() - sock.settimeout(3) + sock = None + # Support for both IPv4 and IPv6. + # On most of IPv6-ready systems, IPv6 will take precedence. 
+ for res in socket.getaddrinfo("localhost", port, socket.AF_UNSPEC, socket.SOCK_STREAM): + af, socktype, proto, canonname, sa = res + try: + sock = socket.socket(af, socktype, proto) + sock.settimeout(3) + sock.connect(sa) + except socket.error: + sock = None + continue + break + if not sock: + raise Exception("could not open socket") try: - sock.connect(("localhost", port)) rf = sock.makefile("rb", 65536) for item in serializer.load_stream(rf): yield item From c8ae887ef02b8f7e2ad06841719fb12eacf1f7f9 Mon Sep 17 00:00:00 2001 From: Rosstin Date: Mon, 29 Jun 2015 14:45:08 -0700 Subject: [PATCH 057/122] [SPARK-8660][ML] Convert JavaDoc style comments inLogisticRegressionSuite.scala to regular multiline comments, to make copy-pasting R commands easier Converted JavaDoc style comments in mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala to regular multiline comments, to make copy-pasting R commands easier. Author: Rosstin Closes #7096 from Rosstin/SPARK-8660 and squashes the following commits: 242aedd [Rosstin] SPARK-8660, changed comment style from JavaDoc style to normal multiline comment in order to make copypaste into R easier, in file classification/LogisticRegressionSuite.scala 2cd2985 [Rosstin] Merge branch 'master' of github.com:apache/spark into SPARK-8639 21ac1e5 [Rosstin] Merge branch 'master' of github.com:apache/spark into SPARK-8639 6c18058 [Rosstin] fixed minor typos in docs/README.md and docs/api.md --- .../LogisticRegressionSuite.scala | 342 +++++++++--------- 1 file changed, 171 insertions(+), 171 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index 5a6265ea992c6..bc6eeac1db5da 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -36,19 +36,19 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { dataset = sqlContext.createDataFrame(generateLogisticInput(1.0, 1.0, nPoints = 100, seed = 42)) - /** - * Here is the instruction describing how to export the test data into CSV format - * so we can validate the training accuracy compared with R's glmnet package. - * - * import org.apache.spark.mllib.classification.LogisticRegressionSuite - * val nPoints = 10000 - * val weights = Array(-0.57997, 0.912083, -0.371077, -0.819866, 2.688191) - * val xMean = Array(5.843, 3.057, 3.758, 1.199) - * val xVariance = Array(0.6856, 0.1899, 3.116, 0.581) - * val data = sc.parallelize(LogisticRegressionSuite.generateMultinomialLogisticInput( - * weights, xMean, xVariance, true, nPoints, 42), 1) - * data.map(x=> x.label + ", " + x.features(0) + ", " + x.features(1) + ", " - * + x.features(2) + ", " + x.features(3)).saveAsTextFile("path") + /* + Here is the instruction describing how to export the test data into CSV format + so we can validate the training accuracy compared with R's glmnet package. 
+ + import org.apache.spark.mllib.classification.LogisticRegressionSuite + val nPoints = 10000 + val weights = Array(-0.57997, 0.912083, -0.371077, -0.819866, 2.688191) + val xMean = Array(5.843, 3.057, 3.758, 1.199) + val xVariance = Array(0.6856, 0.1899, 3.116, 0.581) + val data = sc.parallelize(LogisticRegressionSuite.generateMultinomialLogisticInput( + weights, xMean, xVariance, true, nPoints, 42), 1) + data.map(x=> x.label + ", " + x.features(0) + ", " + x.features(1) + ", " + + x.features(2) + ", " + x.features(3)).saveAsTextFile("path") */ binaryDataset = { val nPoints = 10000 @@ -211,22 +211,22 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { val trainer = (new LogisticRegression).setFitIntercept(true) val model = trainer.fit(binaryDataset) - /** - * Using the following R code to load the data and train the model using glmnet package. - * - * > library("glmnet") - * > data <- read.csv("path", header=FALSE) - * > label = factor(data$V1) - * > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - * > weights = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 0)) - * > weights - * 5 x 1 sparse Matrix of class "dgCMatrix" - * s0 - * (Intercept) 2.8366423 - * data.V2 -0.5895848 - * data.V3 0.8931147 - * data.V4 -0.3925051 - * data.V5 -0.7996864 + /* + Using the following R code to load the data and train the model using glmnet package. + + > library("glmnet") + > data <- read.csv("path", header=FALSE) + > label = factor(data$V1) + > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) + > weights = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 0)) + > weights + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) 2.8366423 + data.V2 -0.5895848 + data.V3 0.8931147 + data.V4 -0.3925051 + data.V5 -0.7996864 */ val interceptR = 2.8366423 val weightsR = Array(-0.5895848, 0.8931147, -0.3925051, -0.7996864) @@ -242,23 +242,23 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { val trainer = (new LogisticRegression).setFitIntercept(false) val model = trainer.fit(binaryDataset) - /** - * Using the following R code to load the data and train the model using glmnet package. - * - * > library("glmnet") - * > data <- read.csv("path", header=FALSE) - * > label = factor(data$V1) - * > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - * > weights = - * coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 0, intercept=FALSE)) - * > weights - * 5 x 1 sparse Matrix of class "dgCMatrix" - * s0 - * (Intercept) . - * data.V2 -0.3534996 - * data.V3 1.2964482 - * data.V4 -0.3571741 - * data.V5 -0.7407946 + /* + Using the following R code to load the data and train the model using glmnet package. + + > library("glmnet") + > data <- read.csv("path", header=FALSE) + > label = factor(data$V1) + > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) + > weights = + coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 0, intercept=FALSE)) + > weights + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) . 
+ data.V2 -0.3534996 + data.V3 1.2964482 + data.V4 -0.3571741 + data.V5 -0.7407946 */ val interceptR = 0.0 val weightsR = Array(-0.3534996, 1.2964482, -0.3571741, -0.7407946) @@ -275,22 +275,22 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { .setElasticNetParam(1.0).setRegParam(0.12) val model = trainer.fit(binaryDataset) - /** - * Using the following R code to load the data and train the model using glmnet package. - * - * > library("glmnet") - * > data <- read.csv("path", header=FALSE) - * > label = factor(data$V1) - * > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - * > weights = coef(glmnet(features,label, family="binomial", alpha = 1, lambda = 0.12)) - * > weights - * 5 x 1 sparse Matrix of class "dgCMatrix" - * s0 - * (Intercept) -0.05627428 - * data.V2 . - * data.V3 . - * data.V4 -0.04325749 - * data.V5 -0.02481551 + /* + Using the following R code to load the data and train the model using glmnet package. + + > library("glmnet") + > data <- read.csv("path", header=FALSE) + > label = factor(data$V1) + > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) + > weights = coef(glmnet(features,label, family="binomial", alpha = 1, lambda = 0.12)) + > weights + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) -0.05627428 + data.V2 . + data.V3 . + data.V4 -0.04325749 + data.V5 -0.02481551 */ val interceptR = -0.05627428 val weightsR = Array(0.0, 0.0, -0.04325749, -0.02481551) @@ -307,23 +307,23 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { .setElasticNetParam(1.0).setRegParam(0.12) val model = trainer.fit(binaryDataset) - /** - * Using the following R code to load the data and train the model using glmnet package. - * - * > library("glmnet") - * > data <- read.csv("path", header=FALSE) - * > label = factor(data$V1) - * > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - * > weights = coef(glmnet(features,label, family="binomial", alpha = 1, lambda = 0.12, - * intercept=FALSE)) - * > weights - * 5 x 1 sparse Matrix of class "dgCMatrix" - * s0 - * (Intercept) . - * data.V2 . - * data.V3 . - * data.V4 -0.05189203 - * data.V5 -0.03891782 + /* + Using the following R code to load the data and train the model using glmnet package. + + > library("glmnet") + > data <- read.csv("path", header=FALSE) + > label = factor(data$V1) + > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) + > weights = coef(glmnet(features,label, family="binomial", alpha = 1, lambda = 0.12, + intercept=FALSE)) + > weights + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) . + data.V2 . + data.V3 . + data.V4 -0.05189203 + data.V5 -0.03891782 */ val interceptR = 0.0 val weightsR = Array(0.0, 0.0, -0.05189203, -0.03891782) @@ -340,22 +340,22 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { .setElasticNetParam(0.0).setRegParam(1.37) val model = trainer.fit(binaryDataset) - /** - * Using the following R code to load the data and train the model using glmnet package. 
- * - * > library("glmnet") - * > data <- read.csv("path", header=FALSE) - * > label = factor(data$V1) - * > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - * > weights = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 1.37)) - * > weights - * 5 x 1 sparse Matrix of class "dgCMatrix" - * s0 - * (Intercept) 0.15021751 - * data.V2 -0.07251837 - * data.V3 0.10724191 - * data.V4 -0.04865309 - * data.V5 -0.10062872 + /* + Using the following R code to load the data and train the model using glmnet package. + + > library("glmnet") + > data <- read.csv("path", header=FALSE) + > label = factor(data$V1) + > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) + > weights = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 1.37)) + > weights + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) 0.15021751 + data.V2 -0.07251837 + data.V3 0.10724191 + data.V4 -0.04865309 + data.V5 -0.10062872 */ val interceptR = 0.15021751 val weightsR = Array(-0.07251837, 0.10724191, -0.04865309, -0.10062872) @@ -372,23 +372,23 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { .setElasticNetParam(0.0).setRegParam(1.37) val model = trainer.fit(binaryDataset) - /** - * Using the following R code to load the data and train the model using glmnet package. - * - * > library("glmnet") - * > data <- read.csv("path", header=FALSE) - * > label = factor(data$V1) - * > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - * > weights = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 1.37, - * intercept=FALSE)) - * > weights - * 5 x 1 sparse Matrix of class "dgCMatrix" - * s0 - * (Intercept) . - * data.V2 -0.06099165 - * data.V3 0.12857058 - * data.V4 -0.04708770 - * data.V5 -0.09799775 + /* + Using the following R code to load the data and train the model using glmnet package. + + > library("glmnet") + > data <- read.csv("path", header=FALSE) + > label = factor(data$V1) + > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) + > weights = coef(glmnet(features,label, family="binomial", alpha = 0, lambda = 1.37, + intercept=FALSE)) + > weights + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) . + data.V2 -0.06099165 + data.V3 0.12857058 + data.V4 -0.04708770 + data.V5 -0.09799775 */ val interceptR = 0.0 val weightsR = Array(-0.06099165, 0.12857058, -0.04708770, -0.09799775) @@ -405,22 +405,22 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { .setElasticNetParam(0.38).setRegParam(0.21) val model = trainer.fit(binaryDataset) - /** - * Using the following R code to load the data and train the model using glmnet package. - * - * > library("glmnet") - * > data <- read.csv("path", header=FALSE) - * > label = factor(data$V1) - * > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - * > weights = coef(glmnet(features,label, family="binomial", alpha = 0.38, lambda = 0.21)) - * > weights - * 5 x 1 sparse Matrix of class "dgCMatrix" - * s0 - * (Intercept) 0.57734851 - * data.V2 -0.05310287 - * data.V3 . - * data.V4 -0.08849250 - * data.V5 -0.15458796 + /* + Using the following R code to load the data and train the model using glmnet package. 
+ + > library("glmnet") + > data <- read.csv("path", header=FALSE) + > label = factor(data$V1) + > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) + > weights = coef(glmnet(features,label, family="binomial", alpha = 0.38, lambda = 0.21)) + > weights + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) 0.57734851 + data.V2 -0.05310287 + data.V3 . + data.V4 -0.08849250 + data.V5 -0.15458796 */ val interceptR = 0.57734851 val weightsR = Array(-0.05310287, 0.0, -0.08849250, -0.15458796) @@ -437,23 +437,23 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { .setElasticNetParam(0.38).setRegParam(0.21) val model = trainer.fit(binaryDataset) - /** - * Using the following R code to load the data and train the model using glmnet package. - * - * > library("glmnet") - * > data <- read.csv("path", header=FALSE) - * > label = factor(data$V1) - * > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - * > weights = coef(glmnet(features,label, family="binomial", alpha = 0.38, lambda = 0.21, - * intercept=FALSE)) - * > weights - * 5 x 1 sparse Matrix of class "dgCMatrix" - * s0 - * (Intercept) . - * data.V2 -0.001005743 - * data.V3 0.072577857 - * data.V4 -0.081203769 - * data.V5 -0.142534158 + /* + Using the following R code to load the data and train the model using glmnet package. + + > library("glmnet") + > data <- read.csv("path", header=FALSE) + > label = factor(data$V1) + > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) + > weights = coef(glmnet(features,label, family="binomial", alpha = 0.38, lambda = 0.21, + intercept=FALSE)) + > weights + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) . + data.V2 -0.001005743 + data.V3 0.072577857 + data.V4 -0.081203769 + data.V5 -0.142534158 */ val interceptR = 0.0 val weightsR = Array(-0.001005743, 0.072577857, -0.081203769, -0.142534158) @@ -480,16 +480,16 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { classSummarizer1.merge(classSummarizer2) }).histogram - /** - * For binary logistic regression with strong L1 regularization, all the weights will be zeros. - * As a result, - * {{{ - * P(0) = 1 / (1 + \exp(b)), and - * P(1) = \exp(b) / (1 + \exp(b)) - * }}}, hence - * {{{ - * b = \log{P(1) / P(0)} = \log{count_1 / count_0} - * }}} + /* + For binary logistic regression with strong L1 regularization, all the weights will be zeros. + As a result, + {{{ + P(0) = 1 / (1 + \exp(b)), and + P(1) = \exp(b) / (1 + \exp(b)) + }}}, hence + {{{ + b = \log{P(1) / P(0)} = \log{count_1 / count_0} + }}} */ val interceptTheory = math.log(histogram(1).toDouble / histogram(0).toDouble) val weightsTheory = Array(0.0, 0.0, 0.0, 0.0) @@ -500,22 +500,22 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { assert(model.weights(2) ~== weightsTheory(2) absTol 1E-6) assert(model.weights(3) ~== weightsTheory(3) absTol 1E-6) - /** - * Using the following R code to load the data and train the model using glmnet package. - * - * > library("glmnet") - * > data <- read.csv("path", header=FALSE) - * > label = factor(data$V1) - * > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) - * > weights = coef(glmnet(features,label, family="binomial", alpha = 1.0, lambda = 6.0)) - * > weights - * 5 x 1 sparse Matrix of class "dgCMatrix" - * s0 - * (Intercept) -0.2480643 - * data.V2 0.0000000 - * data.V3 . - * data.V4 . - * data.V5 . 
+ /* + Using the following R code to load the data and train the model using glmnet package. + + > library("glmnet") + > data <- read.csv("path", header=FALSE) + > label = factor(data$V1) + > features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5)) + > weights = coef(glmnet(features,label, family="binomial", alpha = 1.0, lambda = 6.0)) + > weights + 5 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) -0.2480643 + data.V2 0.0000000 + data.V3 . + data.V4 . + data.V5 . */ val interceptR = -0.248065 val weightsR = Array(0.0, 0.0, 0.0, 0.0) From 931da5c8ab271ff2ee04419c7e3c6b0012459694 Mon Sep 17 00:00:00 2001 From: BenFradet Date: Mon, 29 Jun 2015 15:27:13 -0700 Subject: [PATCH 058/122] [SPARK-8478] [SQL] Harmonize UDF-related code to use uniformly UDF instead of Udf Follow-up of #6902 for being coherent between ```Udf``` and ```UDF``` Author: BenFradet Closes #6920 from BenFradet/SPARK-8478 and squashes the following commits: c500f29 [BenFradet] renamed a few variables in functions to use UDF 8ab0f2d [BenFradet] renamed idUdf to idUDF in SQLQuerySuite 98696c2 [BenFradet] renamed originalUdfs in TestHive to originalUDFs 7738f74 [BenFradet] modified HiveUDFSuite to use only UDF c52608d [BenFradet] renamed HiveUdfSuite to HiveUDFSuite e51b9ac [BenFradet] renamed ExtractPythonUdfs to ExtractPythonUDFs 8c756f1 [BenFradet] renamed Hive UDF related code 2a1ca76 [BenFradet] renamed pythonUdfs to pythonUDFs 261e6fb [BenFradet] renamed ScalaUdf to ScalaUDF --- .../{ScalaUdf.scala => ScalaUDF.scala} | 4 +- .../org/apache/spark/sql/SQLContext.scala | 4 +- .../apache/spark/sql/UDFRegistration.scala | 96 +++++++++--------- .../spark/sql/UserDefinedFunction.scala | 4 +- .../{pythonUdfs.scala => pythonUDFs.scala} | 2 +- .../org/apache/spark/sql/functions.scala | 34 +++---- .../org/apache/spark/sql/SQLQuerySuite.scala | 4 +- .../apache/spark/sql/hive/HiveContext.scala | 4 +- .../org/apache/spark/sql/hive/HiveQl.scala | 2 +- .../hive/{hiveUdfs.scala => hiveUDFs.scala} | 26 ++--- .../apache/spark/sql/hive/test/TestHive.scala | 4 +- .../files/{testUdf => testUDF}/part-00000 | Bin ...{HiveUdfSuite.scala => HiveUDFSuite.scala} | 24 ++--- 13 files changed, 104 insertions(+), 104 deletions(-) rename sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/{ScalaUdf.scala => ScalaUDF.scala} (99%) rename sql/core/src/main/scala/org/apache/spark/sql/execution/{pythonUdfs.scala => pythonUDFs.scala} (99%) rename sql/hive/src/main/scala/org/apache/spark/sql/hive/{hiveUdfs.scala => hiveUDFs.scala} (96%) rename sql/hive/src/test/resources/data/files/{testUdf => testUDF}/part-00000 (100%) rename sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/{HiveUdfSuite.scala => HiveUDFSuite.scala} (93%) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala similarity index 99% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index 55df72f102295..dbb4381d54c4f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.types.DataType * User-defined function. * @param dataType Return type of function. 
*/ -case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expression]) +case class ScalaUDF(function: AnyRef, dataType: DataType, children: Seq[Expression]) extends Expression { override def nullable: Boolean = true @@ -957,6 +957,6 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi private[this] val converter = CatalystTypeConverters.createToCatalystConverter(dataType) override def eval(input: InternalRow): Any = converter(f(input)) - // TODO(davies): make ScalaUdf work with codegen + // TODO(davies): make ScalaUDF work with codegen override def isThreadSafe: Boolean = false } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 8ed44ee141be5..fc14a77538ef1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -146,7 +146,7 @@ class SQLContext(@transient val sparkContext: SparkContext) protected[sql] lazy val analyzer: Analyzer = new Analyzer(catalog, functionRegistry, conf) { override val extendedResolutionRules = - ExtractPythonUdfs :: + ExtractPythonUDFs :: sources.PreInsertCastAndRename :: Nil @@ -257,7 +257,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * * The following example registers a Scala closure as UDF: * {{{ - * sqlContext.udf.register("myUdf", (arg1: Int, arg2: String) => arg2 + arg1) + * sqlContext.udf.register("myUDF", (arg1: Int, arg2: String) => arg2 + arg1) * }}} * * The following example registers a UDF in Java: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala index 3cc5c2441d8a5..03dc37aa73f0c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala @@ -26,7 +26,7 @@ import org.apache.spark.api.python.PythonBroadcast import org.apache.spark.broadcast.Broadcast import org.apache.spark.sql.api.java._ import org.apache.spark.sql.catalyst.ScalaReflection -import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUdf} +import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUDF} import org.apache.spark.sql.execution.PythonUDF import org.apache.spark.sql.types.DataType @@ -95,7 +95,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[$typeTags](name: String, func: Function$x[$types]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) }""") @@ -114,7 +114,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { |def register(name: String, f: UDF$i[$extTypeArgs, _], returnType: DataType) = { | functionRegistry.registerFunction( | name, - | (e: Seq[Expression]) => ScalaUdf(f$anyCast.call($anyParams), returnType, e)) + | (e: Seq[Expression]) => ScalaUDF(f$anyCast.call($anyParams), returnType, e)) |}""".stripMargin) } */ @@ -126,7 +126,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag](name: String, func: Function0[RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, 
dataType, e) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -138,7 +138,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag](name: String, func: Function1[A1, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -150,7 +150,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag](name: String, func: Function2[A1, A2, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -162,7 +162,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag](name: String, func: Function3[A1, A2, A3, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -174,7 +174,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag](name: String, func: Function4[A1, A2, A3, A4, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -186,7 +186,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag](name: String, func: Function5[A1, A2, A3, A4, A5, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -198,7 +198,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag](name: String, func: Function6[A1, A2, A3, A4, A5, A6, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -210,7 +210,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag](name: String, func: Function7[A1, A2, A3, A4, A5, A6, A7, RT]): 
UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -222,7 +222,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag](name: String, func: Function8[A1, A2, A3, A4, A5, A6, A7, A8, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -234,7 +234,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag](name: String, func: Function9[A1, A2, A3, A4, A5, A6, A7, A8, A9, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -246,7 +246,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag](name: String, func: Function10[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -258,7 +258,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag](name: String, func: Function11[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -270,7 +270,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag](name: String, func: Function12[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -282,7 +282,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def 
register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag](name: String, func: Function13[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -294,7 +294,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag](name: String, func: Function14[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -306,7 +306,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag](name: String, func: Function15[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -318,7 +318,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag](name: String, func: Function16[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -330,7 +330,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag](name: String, func: Function17[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -342,7 +342,7 @@ class 
UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag](name: String, func: Function18[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -354,7 +354,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag](name: String, func: Function19[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -366,7 +366,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag](name: String, func: Function20[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -378,7 +378,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag, A21: TypeTag](name: String, func: Function21[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -390,7 +390,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, 
A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag, A21: TypeTag, A22: TypeTag](name: String, func: Function22[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, A22, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - def builder(e: Seq[Expression]) = ScalaUdf(func, dataType, e) + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType) } @@ -405,7 +405,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register(name: String, f: UDF1[_, _], returnType: DataType) = { functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF1[Any, Any]].call(_: Any), returnType, e)) + (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF1[Any, Any]].call(_: Any), returnType, e)) } /** @@ -415,7 +415,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register(name: String, f: UDF2[_, _, _], returnType: DataType) = { functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF2[Any, Any, Any]].call(_: Any, _: Any), returnType, e)) + (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF2[Any, Any, Any]].call(_: Any, _: Any), returnType, e)) } /** @@ -425,7 +425,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register(name: String, f: UDF3[_, _, _, _], returnType: DataType) = { functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF3[Any, Any, Any, Any]].call(_: Any, _: Any, _: Any), returnType, e)) + (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF3[Any, Any, Any, Any]].call(_: Any, _: Any, _: Any), returnType, e)) } /** @@ -435,7 +435,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register(name: String, f: UDF4[_, _, _, _, _], returnType: DataType) = { functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF4[Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any), returnType, e)) + (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF4[Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any), returnType, e)) } /** @@ -445,7 +445,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register(name: String, f: UDF5[_, _, _, _, _, _], returnType: DataType) = { functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF5[Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) + (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF5[Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) } /** @@ -455,7 +455,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register(name: String, f: UDF6[_, _, _, _, _, _, _], returnType: DataType) = { functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF6[Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) + (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF6[Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) } /** @@ -465,7 +465,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register(name: String, f: UDF7[_, _, _, _, _, _, _, _], returnType: 
DataType) = { functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF7[Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) + (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF7[Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) } /** @@ -475,7 +475,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register(name: String, f: UDF8[_, _, _, _, _, _, _, _, _], returnType: DataType) = { functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF8[Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) + (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF8[Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) } /** @@ -485,7 +485,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register(name: String, f: UDF9[_, _, _, _, _, _, _, _, _, _], returnType: DataType) = { functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF9[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) + (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF9[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) } /** @@ -495,7 +495,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register(name: String, f: UDF10[_, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF10[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) + (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF10[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) } /** @@ -505,7 +505,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register(name: String, f: UDF11[_, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF11[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) + (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF11[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) } /** @@ -515,7 +515,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register(name: String, f: UDF12[_, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF12[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) + (e: Seq[Expression]) => 
ScalaUDF(f.asInstanceOf[UDF12[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) } /** @@ -525,7 +525,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register(name: String, f: UDF13[_, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF13[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) + (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF13[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) } /** @@ -535,7 +535,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register(name: String, f: UDF14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF14[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) + (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF14[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) } /** @@ -545,7 +545,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register(name: String, f: UDF15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF15[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) + (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF15[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) } /** @@ -555,7 +555,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register(name: String, f: UDF16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF16[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) + (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF16[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) } /** @@ -565,7 +565,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register(name: String, f: UDF17[_, 
_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF17[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) + (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF17[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) } /** @@ -575,7 +575,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register(name: String, f: UDF18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF18[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) + (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF18[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) } /** @@ -585,7 +585,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register(name: String, f: UDF19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF19[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) + (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF19[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) } /** @@ -595,7 +595,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register(name: String, f: UDF20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF20[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) + (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF20[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) } /** @@ -605,7 +605,7 @@ class 
UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register(name: String, f: UDF21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF21[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) + (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF21[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) } /** @@ -615,7 +615,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { def register(name: String, f: UDF22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { functionRegistry.registerFunction( name, - (e: Seq[Expression]) => ScalaUdf(f.asInstanceOf[UDF22[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) + (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF22[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) } // scalastyle:on diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala index a02e202d2eebc..831eb7eb0fae9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala @@ -23,7 +23,7 @@ import org.apache.spark.Accumulator import org.apache.spark.annotation.Experimental import org.apache.spark.api.python.PythonBroadcast import org.apache.spark.broadcast.Broadcast -import org.apache.spark.sql.catalyst.expressions.ScalaUdf +import org.apache.spark.sql.catalyst.expressions.ScalaUDF import org.apache.spark.sql.execution.PythonUDF import org.apache.spark.sql.types.DataType @@ -44,7 +44,7 @@ import org.apache.spark.sql.types.DataType case class UserDefinedFunction protected[sql] (f: AnyRef, dataType: DataType) { def apply(exprs: Column*): Column = { - Column(ScalaUdf(f, dataType, exprs.map(_.expr))) + Column(ScalaUDF(f, dataType, exprs.map(_.expr))) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala similarity index 99% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala index 036f5d253e385..9e1cff06c7eea 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala @@ -69,7 +69,7 @@ private[spark] case class PythonUDF( * This has the limitation that the input to the Python UDF is not allowed include attributes from * multiple child operators. */ -private[spark] object ExtractPythonUdfs extends Rule[LogicalPlan] { +private[spark] object ExtractPythonUDFs extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { // Skip EvaluatePython nodes. case plan: EvaluatePython => plan diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 5422e066afcb1..4d9a019058228 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1509,7 +1509,7 @@ object functions { (0 to 10).map { x => val args = (1 to x).map(i => s"arg$i: Column").mkString(", ") val fTypes = Seq.fill(x + 1)("_").mkString(", ") - val argsInUdf = (1 to x).map(i => s"arg$i.expr").mkString(", ") + val argsInUDF = (1 to x).map(i => s"arg$i.expr").mkString(", ") println(s""" /** * Call a Scala function of ${x} arguments as user-defined function (UDF). This requires @@ -1521,7 +1521,7 @@ object functions { */ @deprecated("Use udf", "1.5.0") def callUDF(f: Function$x[$fTypes], returnType: DataType${if (args.length > 0) ", " + args else ""}): Column = { - ScalaUdf(f, returnType, Seq($argsInUdf)) + ScalaUDF(f, returnType, Seq($argsInUDF)) }""") } } @@ -1659,7 +1659,7 @@ object functions { */ @deprecated("Use udf", "1.5.0") def callUDF(f: Function0[_], returnType: DataType): Column = { - ScalaUdf(f, returnType, Seq()) + ScalaUDF(f, returnType, Seq()) } /** @@ -1672,7 +1672,7 @@ object functions { */ @deprecated("Use udf", "1.5.0") def callUDF(f: Function1[_, _], returnType: DataType, arg1: Column): Column = { - ScalaUdf(f, returnType, Seq(arg1.expr)) + ScalaUDF(f, returnType, Seq(arg1.expr)) } /** @@ -1685,7 +1685,7 @@ object functions { */ @deprecated("Use udf", "1.5.0") def callUDF(f: Function2[_, _, _], returnType: DataType, arg1: Column, arg2: Column): Column = { - ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr)) + ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr)) } /** @@ -1698,7 +1698,7 @@ object functions { */ @deprecated("Use udf", "1.5.0") def callUDF(f: Function3[_, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column): Column = { - ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr)) + ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr)) } /** @@ -1711,7 +1711,7 @@ object functions { */ @deprecated("Use udf", "1.5.0") def callUDF(f: Function4[_, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column): Column = { - ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr)) + ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr)) } /** @@ -1724,7 +1724,7 @@ object functions { */ @deprecated("Use udf", "1.5.0") def callUDF(f: Function5[_, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column): Column = { - ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr)) + ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr)) } /** @@ -1737,7 +1737,7 @@ object functions { */ @deprecated("Use udf", "1.5.0") def callUDF(f: Function6[_, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: 
Column, arg5: Column, arg6: Column): Column = { - ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr)) + ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr)) } /** @@ -1750,7 +1750,7 @@ object functions { */ @deprecated("Use udf", "1.5.0") def callUDF(f: Function7[_, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column): Column = { - ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr)) + ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr)) } /** @@ -1763,7 +1763,7 @@ object functions { */ @deprecated("Use udf", "1.5.0") def callUDF(f: Function8[_, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column): Column = { - ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr)) + ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr)) } /** @@ -1776,7 +1776,7 @@ object functions { */ @deprecated("Use udf", "1.5.0") def callUDF(f: Function9[_, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column): Column = { - ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr)) + ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr)) } /** @@ -1789,7 +1789,7 @@ object functions { */ @deprecated("Use udf", "1.5.0") def callUDF(f: Function10[_, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column): Column = { - ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr)) + ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr)) } // scalastyle:on @@ -1802,8 +1802,8 @@ object functions { * * val df = Seq(("id1", 1), ("id2", 4), ("id3", 5)).toDF("id", "value") * val sqlContext = df.sqlContext - * sqlContext.udf.register("simpleUdf", (v: Int) => v * v) - * df.select($"id", callUDF("simpleUdf", $"value")) + * sqlContext.udf.register("simpleUDF", (v: Int) => v * v) + * df.select($"id", callUDF("simpleUDF", $"value")) * }}} * * @group udf_funcs @@ -1821,8 +1821,8 @@ object functions { * * val df = Seq(("id1", 1), ("id2", 4), ("id3", 5)).toDF("id", "value") * val sqlContext = df.sqlContext - * sqlContext.udf.register("simpleUdf", (v: Int) => v * v) - * df.select($"id", callUdf("simpleUdf", $"value")) + * sqlContext.udf.register("simpleUDF", (v: Int) => v * v) + * df.select($"id", callUdf("simpleUDF", $"value")) * }}} * * @group udf_funcs diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 22c54e43c1d16..82dc0e9ce5132 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ 
-140,9 +140,9 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { val df = Seq(Tuple1(1), Tuple1(2), Tuple1(3)).toDF("index") // we except the id is materialized once - val idUdf = udf(() => UUID.randomUUID().toString) + val idUDF = udf(() => UUID.randomUUID().toString) - val dfWithId = df.withColumn("id", idUdf()) + val dfWithId = df.withColumn("id", idUDF()) // Make a new DataFrame (actually the same reference to the old one) val cached = dfWithId.cache() // Trigger the cache diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 8021f915bb821..b91242af2d155 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -42,7 +42,7 @@ import org.apache.spark.sql.SQLConf.SQLConfEntry._ import org.apache.spark.sql.catalyst.ParserDialect import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.execution.{ExecutedCommand, ExtractPythonUdfs, SetCommand} +import org.apache.spark.sql.execution.{ExecutedCommand, ExtractPythonUDFs, SetCommand} import org.apache.spark.sql.hive.client._ import org.apache.spark.sql.hive.execution.{DescribeHiveTableCommand, HiveNativeCommand} import org.apache.spark.sql.sources.DataSourceStrategy @@ -381,7 +381,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { catalog.ParquetConversions :: catalog.CreateTables :: catalog.PreInsertionCasts :: - ExtractPythonUdfs :: + ExtractPythonUDFs :: ResolveHiveWindowFunction :: sources.PreInsertCastAndRename :: Nil diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 7c4620952ba4b..2de7a99c122fd 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -1638,7 +1638,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C sys.error(s"Couldn't find function $functionName")) val functionClassName = functionInfo.getFunctionClass.getName - (HiveGenericUdtf( + (HiveGenericUDTF( new HiveFunctionWrapper(functionClassName), children.map(nodeToExpr)), attributes) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala similarity index 96% rename from sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala rename to sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala index 4986b1ea9d906..d7827d56ca8c5 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala @@ -59,16 +59,16 @@ private[hive] class HiveFunctionRegistry(underlying: analysis.FunctionRegistry) val functionClassName = functionInfo.getFunctionClass.getName if (classOf[UDF].isAssignableFrom(functionInfo.getFunctionClass)) { - HiveSimpleUdf(new HiveFunctionWrapper(functionClassName), children) + HiveSimpleUDF(new HiveFunctionWrapper(functionClassName), children) } else if (classOf[GenericUDF].isAssignableFrom(functionInfo.getFunctionClass)) { - HiveGenericUdf(new HiveFunctionWrapper(functionClassName), children) + HiveGenericUDF(new HiveFunctionWrapper(functionClassName), children) } else if ( 
classOf[AbstractGenericUDAFResolver].isAssignableFrom(functionInfo.getFunctionClass)) { - HiveGenericUdaf(new HiveFunctionWrapper(functionClassName), children) + HiveGenericUDAF(new HiveFunctionWrapper(functionClassName), children) } else if (classOf[UDAF].isAssignableFrom(functionInfo.getFunctionClass)) { - HiveUdaf(new HiveFunctionWrapper(functionClassName), children) + HiveUDAF(new HiveFunctionWrapper(functionClassName), children) } else if (classOf[GenericUDTF].isAssignableFrom(functionInfo.getFunctionClass)) { - HiveGenericUdtf(new HiveFunctionWrapper(functionClassName), children) + HiveGenericUDTF(new HiveFunctionWrapper(functionClassName), children) } else { sys.error(s"No handler for udf ${functionInfo.getFunctionClass}") } @@ -79,7 +79,7 @@ private[hive] class HiveFunctionRegistry(underlying: analysis.FunctionRegistry) throw new UnsupportedOperationException } -private[hive] case class HiveSimpleUdf(funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) +private[hive] case class HiveSimpleUDF(funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) extends Expression with HiveInspectors with Logging { type UDFType = UDF @@ -146,7 +146,7 @@ private[hive] class DeferredObjectAdapter(oi: ObjectInspector) override def get(): AnyRef = wrap(func(), oi) } -private[hive] case class HiveGenericUdf(funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) +private[hive] case class HiveGenericUDF(funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) extends Expression with HiveInspectors with Logging { type UDFType = GenericUDF @@ -413,7 +413,7 @@ private[hive] case class HiveWindowFunction( new HiveWindowFunction(funcWrapper, pivotResult, isUDAFBridgeRequired, children) } -private[hive] case class HiveGenericUdaf( +private[hive] case class HiveGenericUDAF( funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) extends AggregateExpression with HiveInspectors { @@ -441,11 +441,11 @@ private[hive] case class HiveGenericUdaf( s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})" } - def newInstance(): HiveUdafFunction = new HiveUdafFunction(funcWrapper, children, this) + def newInstance(): HiveUDAFFunction = new HiveUDAFFunction(funcWrapper, children, this) } /** It is used as a wrapper for the hive functions which uses UDAF interface */ -private[hive] case class HiveUdaf( +private[hive] case class HiveUDAF( funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) extends AggregateExpression with HiveInspectors { @@ -474,7 +474,7 @@ private[hive] case class HiveUdaf( s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})" } - def newInstance(): HiveUdafFunction = new HiveUdafFunction(funcWrapper, children, this, true) + def newInstance(): HiveUDAFFunction = new HiveUDAFFunction(funcWrapper, children, this, true) } /** @@ -488,7 +488,7 @@ private[hive] case class HiveUdaf( * Operators that require maintaining state in between input rows should instead be implemented as * user defined aggregations, which have clean semantics even in a partitioned execution. 
*/ -private[hive] case class HiveGenericUdtf( +private[hive] case class HiveGenericUDTF( funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) extends Generator with HiveInspectors { @@ -553,7 +553,7 @@ private[hive] case class HiveGenericUdtf( } } -private[hive] case class HiveUdafFunction( +private[hive] case class HiveUDAFFunction( funcWrapper: HiveFunctionWrapper, exprs: Seq[Expression], base: AggregateExpression, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index ea325cc93cb85..7978fdacaedba 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -391,7 +391,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { * Records the UDFs present when the server starts, so we can delete ones that are created by * tests. */ - protected val originalUdfs: JavaSet[String] = FunctionRegistry.getFunctionNames + protected val originalUDFs: JavaSet[String] = FunctionRegistry.getFunctionNames /** * Resets the test instance by deleting any tables that have been created. @@ -410,7 +410,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { catalog.client.reset() catalog.unregisterAllTables() - FunctionRegistry.getFunctionNames.filterNot(originalUdfs.contains(_)).foreach { udfName => + FunctionRegistry.getFunctionNames.filterNot(originalUDFs.contains(_)).foreach { udfName => FunctionRegistry.unregisterTemporaryUDF(udfName) } diff --git a/sql/hive/src/test/resources/data/files/testUdf/part-00000 b/sql/hive/src/test/resources/data/files/testUDF/part-00000 similarity index 100% rename from sql/hive/src/test/resources/data/files/testUdf/part-00000 rename to sql/hive/src/test/resources/data/files/testUDF/part-00000 diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala similarity index 93% rename from sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala rename to sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala index ce5985888f540..56b0bef1d0571 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala @@ -46,7 +46,7 @@ case class ListStringCaseClass(l: Seq[String]) /** * A test suite for Hive custom UDFs. */ -class HiveUdfSuite extends QueryTest { +class HiveUDFSuite extends QueryTest { import TestHive.{udf, sql} import TestHive.implicits._ @@ -73,7 +73,7 @@ class HiveUdfSuite extends QueryTest { test("hive struct udf") { sql( """ - |CREATE EXTERNAL TABLE hiveUdfTestTable ( + |CREATE EXTERNAL TABLE hiveUDFTestTable ( | pair STRUCT |) |PARTITIONED BY (partition STRING) @@ -82,15 +82,15 @@ class HiveUdfSuite extends QueryTest { """. 
stripMargin.format(classOf[PairSerDe].getName)) - val location = Utils.getSparkClassLoader.getResource("data/files/testUdf").getFile + val location = Utils.getSparkClassLoader.getResource("data/files/testUDF").getFile sql(s""" - ALTER TABLE hiveUdfTestTable - ADD IF NOT EXISTS PARTITION(partition='testUdf') + ALTER TABLE hiveUDFTestTable + ADD IF NOT EXISTS PARTITION(partition='testUDF') LOCATION '$location'""") - sql(s"CREATE TEMPORARY FUNCTION testUdf AS '${classOf[PairUdf].getName}'") - sql("SELECT testUdf(pair) FROM hiveUdfTestTable") - sql("DROP TEMPORARY FUNCTION IF EXISTS testUdf") + sql(s"CREATE TEMPORARY FUNCTION testUDF AS '${classOf[PairUDF].getName}'") + sql("SELECT testUDF(pair) FROM hiveUDFTestTable") + sql("DROP TEMPORARY FUNCTION IF EXISTS testUDF") } test("SPARK-6409 UDAFAverage test") { @@ -169,11 +169,11 @@ class HiveUdfSuite extends QueryTest { StringCaseClass("world") :: StringCaseClass("goodbye") :: Nil).toDF() testData.registerTempTable("stringTable") - sql(s"CREATE TEMPORARY FUNCTION testStringStringUdf AS '${classOf[UDFStringString].getName}'") + sql(s"CREATE TEMPORARY FUNCTION testStringStringUDF AS '${classOf[UDFStringString].getName}'") checkAnswer( - sql("SELECT testStringStringUdf(\"hello\", s) FROM stringTable"), + sql("SELECT testStringStringUDF(\"hello\", s) FROM stringTable"), Seq(Row("hello world"), Row("hello goodbye"))) - sql("DROP TEMPORARY FUNCTION IF EXISTS testStringStringUdf") + sql("DROP TEMPORARY FUNCTION IF EXISTS testStringStringUDF") TestHive.reset() } @@ -244,7 +244,7 @@ class PairSerDe extends AbstractSerDe { } } -class PairUdf extends GenericUDF { +class PairUDF extends GenericUDF { override def initialize(p1: Array[ObjectInspector]): ObjectInspector = ObjectInspectorFactory.getStandardStructObjectInspector( Seq("id", "value"), From ed359de595d5dd67b666660eddf092eaf89041c8 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 29 Jun 2015 15:59:20 -0700 Subject: [PATCH 059/122] [SPARK-8579] [SQL] support arbitrary object in UnsafeRow This PR brings arbitrary object support in UnsafeRow (both in grouping key and aggregation buffer). Two object pools will be created to hold those non-primitive objects, and put the index of them into UnsafeRow. In order to compare the grouping key as bytes, the objects in key will be stored in a unique object pool, to make sure same objects will have same index (used as hashCode). For StringType and BinaryType, we still put them as var-length in UnsafeRow when initializing for better performance. But for update, they will be an object inside object pools (there will be some garbages left in the buffer). BTW: Will create a JIRA once issue.apache.org is available. 
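To make the pooling scheme described above concrete, here is a minimal, self-contained Scala sketch of the idea (hypothetical names; this is not the ObjectPool/UniqueObjectPool code added by the patch below): non-primitive values are appended to a pool and the row keeps only their index, and the unique variant hands back the same index for equal key objects, so hashing and comparing the stored index behaves like hashing and comparing the key itself.

    // Sketch only: an array-backed pool that grows as needed.
    object PoolSketch {
      class SimplePool(initialCapacity: Int) {
        private var objects = new Array[AnyRef](initialCapacity)
        private var numObj = 0

        def size: Int = numObj
        def get(idx: Int): AnyRef = objects(idx)

        // Appends obj and returns its index, doubling the backing array when full.
        def put(obj: AnyRef): Int = {
          if (numObj >= objects.length) {
            objects = java.util.Arrays.copyOf(objects, objects.length * 2)
          }
          objects(numObj) = obj
          numObj += 1
          numObj - 1
        }
      }

      // Unique variant: equal objects get the same index, so a grouping key stored as an
      // index can be hashed and compared byte-wise like any other fixed-width field.
      class SimpleUniquePool(initialCapacity: Int) extends SimplePool(initialCapacity) {
        private val index = new java.util.HashMap[AnyRef, Integer]()
        override def put(obj: AnyRef): Int = {
          val existing = index.get(obj)
          if (existing != null) {
            existing.intValue()
          } else {
            val idx = super.put(obj)
            index.put(obj, idx)
            idx
          }
        }
      }

      def main(args: Array[String]): Unit = {
        val keys = new SimpleUniquePool(4)
        // The same key value always maps to the same slot.
        assert(keys.put("spark") == keys.put("spark"))
        val buffers = new SimplePool(4)
        val idx = buffers.put(Array[Byte](1, 2, 3))
        assert(buffers.get(idx).isInstanceOf[Array[Byte]])
      }
    }

The row-level encoding then only needs a convention for telling inline values apart from pooled ones; in the patch that follows, non-positive words are treated as pool indexes while strings and binary keep an inline (flag, offset, length) packing.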
cc JoshRosen rxin Author: Davies Liu Closes #6959 from davies/unsafe_obj and squashes the following commits: 5ce39da [Davies Liu] fix comment 5e797bf [Davies Liu] Merge branch 'master' of github.com:apache/spark into unsafe_obj 5803d64 [Davies Liu] fix conflict 461d304 [Davies Liu] Merge branch 'master' of github.com:apache/spark into unsafe_obj 2f41c90 [Davies Liu] Merge branch 'master' of github.com:apache/spark into unsafe_obj b04d69c [Davies Liu] address comments 4859b80 [Davies Liu] fix comments f38011c [Davies Liu] add a test for grouping by decimal d2cf7ab [Davies Liu] add more tests for null checking 71983c5 [Davies Liu] add test for timestamp e8a1649 [Davies Liu] reuse buffer for string 39f09ca [Davies Liu] Merge branch 'master' of github.com:apache/spark into unsafe_obj 035501e [Davies Liu] fix style 236d6de [Davies Liu] support arbitrary object in UnsafeRow --- .../UnsafeFixedWidthAggregationMap.java | 144 ++++++------ .../sql/catalyst/expressions/UnsafeRow.java | 218 +++++++++--------- .../spark/sql/catalyst/util/ObjectPool.java | 78 +++++++ .../sql/catalyst/util/UniqueObjectPool.java | 59 +++++ .../spark/sql/catalyst/InternalRow.scala | 5 +- .../expressions/UnsafeRowConverter.scala | 94 +++----- .../UnsafeFixedWidthAggregationMapSuite.scala | 65 ++++-- .../expressions/UnsafeRowConverterSuite.scala | 190 +++++++++++---- .../sql/catalyst/util/ObjectPoolSuite.scala | 57 +++++ .../sql/execution/GeneratedAggregate.scala | 16 +- 10 files changed, 615 insertions(+), 311 deletions(-) create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/ObjectPool.java create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/UniqueObjectPool.java create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ObjectPoolSuite.scala diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java index 83f2a312972fb..1e79f4b2e88e5 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMap.java @@ -19,9 +19,11 @@ import java.util.Iterator; +import scala.Function1; + import org.apache.spark.sql.catalyst.InternalRow; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.catalyst.util.ObjectPool; +import org.apache.spark.sql.catalyst.util.UniqueObjectPool; import org.apache.spark.unsafe.PlatformDependent; import org.apache.spark.unsafe.map.BytesToBytesMap; import org.apache.spark.unsafe.memory.MemoryLocation; @@ -38,26 +40,48 @@ public final class UnsafeFixedWidthAggregationMap { * An empty aggregation buffer, encoded in UnsafeRow format. When inserting a new key into the * map, we copy this buffer and use it as the value. */ - private final byte[] emptyAggregationBuffer; + private final byte[] emptyBuffer; - private final StructType aggregationBufferSchema; + /** + * An empty row used by `initProjection` + */ + private static final InternalRow emptyRow = new GenericInternalRow(); - private final StructType groupingKeySchema; + /** + * Whether can the empty aggregation buffer be reuse without calling `initProjection` or not. + */ + private final boolean reuseEmptyBuffer; /** - * Encodes grouping keys as UnsafeRows. 
+ * The projection used to initialize the emptyBuffer */ - private final UnsafeRowConverter groupingKeyToUnsafeRowConverter; + private final Function1 initProjection; + + /** + * Encodes grouping keys or buffers as UnsafeRows. + */ + private final UnsafeRowConverter keyConverter; + private final UnsafeRowConverter bufferConverter; /** * A hashmap which maps from opaque bytearray keys to bytearray values. */ private final BytesToBytesMap map; + /** + * An object pool for objects that are used in grouping keys. + */ + private final UniqueObjectPool keyPool; + + /** + * An object pool for objects that are used in aggregation buffers. + */ + private final ObjectPool bufferPool; + /** * Re-used pointer to the current aggregation buffer */ - private final UnsafeRow currentAggregationBuffer = new UnsafeRow(); + private final UnsafeRow currentBuffer = new UnsafeRow(); /** * Scratch space that is used when encoding grouping keys into UnsafeRow format. @@ -69,68 +93,39 @@ public final class UnsafeFixedWidthAggregationMap { private final boolean enablePerfMetrics; - /** - * @return true if UnsafeFixedWidthAggregationMap supports grouping keys with the given schema, - * false otherwise. - */ - public static boolean supportsGroupKeySchema(StructType schema) { - for (StructField field: schema.fields()) { - if (!UnsafeRow.readableFieldTypes.contains(field.dataType())) { - return false; - } - } - return true; - } - - /** - * @return true if UnsafeFixedWidthAggregationMap supports aggregation buffers with the given - * schema, false otherwise. - */ - public static boolean supportsAggregationBufferSchema(StructType schema) { - for (StructField field: schema.fields()) { - if (!UnsafeRow.settableFieldTypes.contains(field.dataType())) { - return false; - } - } - return true; - } - /** * Create a new UnsafeFixedWidthAggregationMap. * - * @param emptyAggregationBuffer the default value for new keys (a "zero" of the agg. function) - * @param aggregationBufferSchema the schema of the aggregation buffer, used for row conversion. - * @param groupingKeySchema the schema of the grouping key, used for row conversion. + * @param initProjection the default value for new keys (a "zero" of the agg. function) + * @param keyConverter the converter of the grouping key, used for row conversion. + * @param bufferConverter the converter of the aggregation buffer, used for row conversion. * @param memoryManager the memory manager used to allocate our Unsafe memory structures. * @param initialCapacity the initial capacity of the map (a sizing hint to avoid re-hashing). 
* @param enablePerfMetrics if true, performance metrics will be recorded (has minor perf impact) */ public UnsafeFixedWidthAggregationMap( - InternalRow emptyAggregationBuffer, - StructType aggregationBufferSchema, - StructType groupingKeySchema, + Function1 initProjection, + UnsafeRowConverter keyConverter, + UnsafeRowConverter bufferConverter, TaskMemoryManager memoryManager, int initialCapacity, boolean enablePerfMetrics) { - this.emptyAggregationBuffer = - convertToUnsafeRow(emptyAggregationBuffer, aggregationBufferSchema); - this.aggregationBufferSchema = aggregationBufferSchema; - this.groupingKeyToUnsafeRowConverter = new UnsafeRowConverter(groupingKeySchema); - this.groupingKeySchema = groupingKeySchema; - this.map = new BytesToBytesMap(memoryManager, initialCapacity, enablePerfMetrics); + this.initProjection = initProjection; + this.keyConverter = keyConverter; + this.bufferConverter = bufferConverter; this.enablePerfMetrics = enablePerfMetrics; - } - /** - * Convert a Java object row into an UnsafeRow, allocating it into a new byte array. - */ - private static byte[] convertToUnsafeRow(InternalRow javaRow, StructType schema) { - final UnsafeRowConverter converter = new UnsafeRowConverter(schema); - final byte[] unsafeRow = new byte[converter.getSizeRequirement(javaRow)]; - final int writtenLength = - converter.writeRow(javaRow, unsafeRow, PlatformDependent.BYTE_ARRAY_OFFSET); - assert (writtenLength == unsafeRow.length): "Size requirement calculation was wrong!"; - return unsafeRow; + this.map = new BytesToBytesMap(memoryManager, initialCapacity, enablePerfMetrics); + this.keyPool = new UniqueObjectPool(100); + this.bufferPool = new ObjectPool(initialCapacity); + + InternalRow initRow = initProjection.apply(emptyRow); + this.emptyBuffer = new byte[bufferConverter.getSizeRequirement(initRow)]; + int writtenLength = bufferConverter.writeRow( + initRow, emptyBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, bufferPool); + assert (writtenLength == emptyBuffer.length): "Size requirement calculation was wrong!"; + // re-use the empty buffer only when there is no object saved in pool. + reuseEmptyBuffer = bufferPool.size() == 0; } /** @@ -138,15 +133,16 @@ private static byte[] convertToUnsafeRow(InternalRow javaRow, StructType schema) * return the same object. */ public UnsafeRow getAggregationBuffer(InternalRow groupingKey) { - final int groupingKeySize = groupingKeyToUnsafeRowConverter.getSizeRequirement(groupingKey); + final int groupingKeySize = keyConverter.getSizeRequirement(groupingKey); // Make sure that the buffer is large enough to hold the key. 
If it's not, grow it: if (groupingKeySize > groupingKeyConversionScratchSpace.length) { groupingKeyConversionScratchSpace = new byte[groupingKeySize]; } - final int actualGroupingKeySize = groupingKeyToUnsafeRowConverter.writeRow( + final int actualGroupingKeySize = keyConverter.writeRow( groupingKey, groupingKeyConversionScratchSpace, - PlatformDependent.BYTE_ARRAY_OFFSET); + PlatformDependent.BYTE_ARRAY_OFFSET, + keyPool); assert (groupingKeySize == actualGroupingKeySize) : "Size requirement calculation was wrong!"; // Probe our map using the serialized key @@ -157,25 +153,31 @@ public UnsafeRow getAggregationBuffer(InternalRow groupingKey) { if (!loc.isDefined()) { // This is the first time that we've seen this grouping key, so we'll insert a copy of the // empty aggregation buffer into the map: + if (!reuseEmptyBuffer) { + // There is some objects referenced by emptyBuffer, so generate a new one + InternalRow initRow = initProjection.apply(emptyRow); + bufferConverter.writeRow(initRow, emptyBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, + bufferPool); + } loc.putNewKey( groupingKeyConversionScratchSpace, PlatformDependent.BYTE_ARRAY_OFFSET, groupingKeySize, - emptyAggregationBuffer, + emptyBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, - emptyAggregationBuffer.length + emptyBuffer.length ); } // Reset the pointer to point to the value that we just stored or looked up: final MemoryLocation address = loc.getValueAddress(); - currentAggregationBuffer.pointTo( + currentBuffer.pointTo( address.getBaseObject(), address.getBaseOffset(), - aggregationBufferSchema.length(), - aggregationBufferSchema + bufferConverter.numFields(), + bufferPool ); - return currentAggregationBuffer; + return currentBuffer; } /** @@ -211,14 +213,14 @@ public MapEntry next() { entry.key.pointTo( keyAddress.getBaseObject(), keyAddress.getBaseOffset(), - groupingKeySchema.length(), - groupingKeySchema + keyConverter.numFields(), + keyPool ); entry.value.pointTo( valueAddress.getBaseObject(), valueAddress.getBaseOffset(), - aggregationBufferSchema.length(), - aggregationBufferSchema + bufferConverter.numFields(), + bufferPool ); return entry; } @@ -246,6 +248,8 @@ public void printPerfMetrics() { System.out.println("Number of hash collisions: " + map.getNumHashCollisions()); System.out.println("Time spent resizing (ns): " + map.getTimeSpentResizingNs()); System.out.println("Total memory consumption (bytes): " + map.getTotalMemoryConsumption()); + System.out.println("Number of unique objects in keys: " + keyPool.size()); + System.out.println("Number of objects in buffers: " + bufferPool.size()); } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index 11d51d90f1802..f077064a02ec0 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -17,20 +17,12 @@ package org.apache.spark.sql.catalyst.expressions; -import javax.annotation.Nullable; -import java.util.Arrays; -import java.util.Collections; -import java.util.HashSet; -import java.util.Set; - import org.apache.spark.sql.catalyst.InternalRow; -import org.apache.spark.sql.types.DataType; -import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.catalyst.util.ObjectPool; import org.apache.spark.unsafe.PlatformDependent; import org.apache.spark.unsafe.bitset.BitSetMethods; import 
org.apache.spark.unsafe.types.UTF8String; -import static org.apache.spark.sql.types.DataTypes.*; /** * An Unsafe implementation of Row which is backed by raw memory instead of Java objects. @@ -44,7 +36,20 @@ * primitive types, such as long, double, or int, we store the value directly in the word. For * fields with non-primitive or variable-length values, we store a relative offset (w.r.t. the * base address of the row) that points to the beginning of the variable-length field, and length - * (they are combined into a long). + * (they are combined into a long). Other objects are stored in a pool; the index of + * each is held in the word. + * + * In order to support fast hashing and equality checks for UnsafeRows that contain objects + * when used as a grouping key in BytesToBytesMap, we put the objects in a UniqueObjectPool to make + * sure all the keys have the same index for the same object, so we can hash/compare the objects by + * hashing/comparing the index. + * + * For non-primitive types, the word of a field could be: + * UNION { + * [1] [offset: 31bits] [length: 31bits] // StringType + * [0] [offset: 31bits] [length: 31bits] // BinaryType + * - [index: 63bits] // StringType, Binary, index to object in pool + * } * * Instances of `UnsafeRow` act as pointers to row data stored in this format. */ @@ -53,8 +58,12 @@ public final class UnsafeRow extends MutableRow { private Object baseObject; private long baseOffset; + /** A pool to hold non-primitive objects */ + private ObjectPool pool; + Object getBaseObject() { return baseObject; } long getBaseOffset() { return baseOffset; } + ObjectPool getPool() { return pool; } /** The number of fields in this row, used for calculating the bitset width (and in assertions) */ private int numFields; @@ -63,15 +72,6 @@ public final class UnsafeRow extends MutableRow { /** The width of the null tracking bit set, in bytes */ private int bitSetWidthInBytes; - /** - * This optional schema is required if you want to call generic get() and set() methods on - * this UnsafeRow, but is optional if callers will only use type-specific getTYPE() and setTYPE() - * methods. This should be removed after the planned InternalRow / Row split; right now, it's only - * needed by the generic get() method, which is only called internally by code that accesses - * UTF8String-typed columns. - */ - @Nullable - private StructType schema; private long getFieldOffset(int ordinal) { return baseOffset + bitSetWidthInBytes + ordinal * 8L; @@ -81,42 +81,7 @@ public static int calculateBitSetWidthInBytes(int numFields) { return ((numFields / 64) + (numFields % 64 == 0 ? 0 : 1)) * 8; } - /** - * Field types that can be updated in place in UnsafeRows (e.g. we support set() for these types) - */ - public static final Set settableFieldTypes; - - /** - * Fields types can be read(but not set (e.g. set() will throw UnsupportedOperationException).
- */ - public static final Set readableFieldTypes; - - // TODO: support DecimalType - static { - settableFieldTypes = Collections.unmodifiableSet( - new HashSet( - Arrays.asList(new DataType[] { - NullType, - BooleanType, - ByteType, - ShortType, - IntegerType, - LongType, - FloatType, - DoubleType, - DateType, - TimestampType - }))); - - // We support get() on a superset of the types for which we support set(): - final Set _readableFieldTypes = new HashSet( - Arrays.asList(new DataType[]{ - StringType, - BinaryType - })); - _readableFieldTypes.addAll(settableFieldTypes); - readableFieldTypes = Collections.unmodifiableSet(_readableFieldTypes); - } + public static final long OFFSET_BITS = 31L; /** * Construct a new UnsafeRow. The resulting row won't be usable until `pointTo()` has been called, @@ -130,22 +95,15 @@ public UnsafeRow() { } * @param baseObject the base object * @param baseOffset the offset within the base object * @param numFields the number of fields in this row - * @param schema an optional schema; this is necessary if you want to call generic get() or set() - * methods on this row, but is optional if the caller will only use type-specific - * getTYPE() and setTYPE() methods. + * @param pool the object pool to hold arbitrary objects */ - public void pointTo( - Object baseObject, - long baseOffset, - int numFields, - @Nullable StructType schema) { + public void pointTo(Object baseObject, long baseOffset, int numFields, ObjectPool pool) { assert numFields >= 0 : "numFields should >= 0"; - assert schema == null || schema.fields().length == numFields; this.bitSetWidthInBytes = calculateBitSetWidthInBytes(numFields); this.baseObject = baseObject; this.baseOffset = baseOffset; this.numFields = numFields; - this.schema = schema; + this.pool = pool; } private void assertIndexIsValid(int index) { @@ -168,9 +126,68 @@ private void setNotNullAt(int i) { BitSetMethods.unset(baseObject, baseOffset, i); } + /** + * Updates the column `i` as Object `value`, which cannot be primitive types. + */ @Override - public void update(int ordinal, Object value) { - throw new UnsupportedOperationException(); + public void update(int i, Object value) { + if (value == null) { + if (!isNullAt(i)) { + // remove the old value from pool + long idx = getLong(i); + if (idx <= 0) { + // this is the index of old value in pool, remove it + pool.replace((int)-idx, null); + } else { + // there will be some garbage left (UTF8String or byte[]) + } + setNullAt(i); + } + return; + } + + if (isNullAt(i)) { + // there is not an old value, put the new value into pool + int idx = pool.put(value); + setLong(i, (long)-idx); + } else { + // there is an old value, check the type, then replace it or update it + long v = getLong(i); + if (v <= 0) { + // it's the index in the pool, replace old value with new one + int idx = (int)-v; + pool.replace(idx, value); + } else { + // old value is UTF8String or byte[], try to reuse the space + boolean isString; + byte[] newBytes; + if (value instanceof UTF8String) { + newBytes = ((UTF8String) value).getBytes(); + isString = true; + } else { + newBytes = (byte[]) value; + isString = false; + } + int offset = (int) ((v >> OFFSET_BITS) & Integer.MAX_VALUE); + int oldLength = (int) (v & Integer.MAX_VALUE); + if (newBytes.length <= oldLength) { + // the new value can fit in the old buffer, re-use it + PlatformDependent.copyMemory( + newBytes, + PlatformDependent.BYTE_ARRAY_OFFSET, + baseObject, + baseOffset + offset, + newBytes.length); + long flag = isString ? 
1L << (OFFSET_BITS * 2) : 0L; + setLong(i, flag | (((long) offset) << OFFSET_BITS) | (long) newBytes.length); + } else { + // Cannot fit in the buffer + int idx = pool.put(value); + setLong(i, (long) -idx); + } + } + } + setNotNullAt(i); } @Override @@ -227,28 +244,38 @@ public int size() { return numFields; } - @Override - public StructType schema() { - return schema; - } - + /** + * Returns the object for column `i`, which should not be primitive type. + */ @Override public Object get(int i) { assertIndexIsValid(i); - assert (schema != null) : "Schema must be defined when calling generic get() method"; - final DataType dataType = schema.fields()[i].dataType(); - // UnsafeRow is only designed to be invoked by internal code, which only invokes this generic - // get() method when trying to access UTF8String-typed columns. If we refactor the codebase to - // separate the internal and external row interfaces, then internal code can fetch strings via - // a new getUTF8String() method and we'll be able to remove this method. if (isNullAt(i)) { return null; - } else if (dataType == StringType) { - return getUTF8String(i); - } else if (dataType == BinaryType) { - return getBinary(i); + } + long v = PlatformDependent.UNSAFE.getLong(baseObject, getFieldOffset(i)); + if (v <= 0) { + // It's an index to object in the pool. + int idx = (int)-v; + return pool.get(idx); } else { - throw new UnsupportedOperationException(); + // The column could be StingType or BinaryType + boolean isString = (v >> (OFFSET_BITS * 2)) > 0; + int offset = (int) ((v >> OFFSET_BITS) & Integer.MAX_VALUE); + int size = (int) (v & Integer.MAX_VALUE); + final byte[] bytes = new byte[size]; + PlatformDependent.copyMemory( + baseObject, + baseOffset + offset, + bytes, + PlatformDependent.BYTE_ARRAY_OFFSET, + size + ); + if (isString) { + return UTF8String.fromBytes(bytes); + } else { + return bytes; + } } } @@ -308,31 +335,6 @@ public double getDouble(int i) { } } - public UTF8String getUTF8String(int i) { - return UTF8String.fromBytes(getBinary(i)); - } - - public byte[] getBinary(int i) { - assertIndexIsValid(i); - final long offsetAndSize = getLong(i); - final int offset = (int)(offsetAndSize >> 32); - final int size = (int)(offsetAndSize & ((1L << 32) - 1)); - final byte[] bytes = new byte[size]; - PlatformDependent.copyMemory( - baseObject, - baseOffset + offset, - bytes, - PlatformDependent.BYTE_ARRAY_OFFSET, - size - ); - return bytes; - } - - @Override - public String getString(int i) { - return getUTF8String(i).toString(); - } - @Override public InternalRow copy() { throw new UnsupportedOperationException(); diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/ObjectPool.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/ObjectPool.java new file mode 100644 index 0000000000000..97f89a7d0b758 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/ObjectPool.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.util; + +/** + * A object pool stores a collection of objects in array, then they can be referenced by the + * pool plus an index. + */ +public class ObjectPool { + + /** + * An array to hold objects, which will grow as needed. + */ + private Object[] objects; + + /** + * How many objects in the pool. + */ + private int numObj; + + public ObjectPool(int capacity) { + objects = new Object[capacity]; + numObj = 0; + } + + /** + * Returns how many objects in the pool. + */ + public int size() { + return numObj; + } + + /** + * Returns the object at position `idx` in the array. + */ + public Object get(int idx) { + assert (idx < numObj); + return objects[idx]; + } + + /** + * Puts an object `obj` at the end of array, returns the index of it. + *

+ * The array will grow as needed. + */ + public int put(Object obj) { + if (numObj >= objects.length) { + Object[] tmp = new Object[objects.length * 2]; + System.arraycopy(objects, 0, tmp, 0, objects.length); + objects = tmp; + } + objects[numObj++] = obj; + return numObj - 1; + } + + /** + * Replaces the object at `idx` with new one `obj`. + */ + public void replace(int idx, Object obj) { + assert (idx < numObj); + objects[idx] = obj; + } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/UniqueObjectPool.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/UniqueObjectPool.java new file mode 100644 index 0000000000000..d512392dcaacc --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/util/UniqueObjectPool.java @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.util; + +import java.util.HashMap; + +/** + * An unique object pool stores a collection of unique objects in it. + */ +public class UniqueObjectPool extends ObjectPool { + + /** + * A hash map from objects to their indexes in the array. + */ + private HashMap objIndex; + + public UniqueObjectPool(int capacity) { + super(capacity); + objIndex = new HashMap(); + } + + /** + * Put an object `obj` into the pool. If there is an existing object equals to `obj`, it will + * return the index of the existing one. + */ + @Override + public int put(Object obj) { + if (objIndex.containsKey(obj)) { + return objIndex.get(obj); + } else { + int idx = super.put(obj); + objIndex.put(obj, idx); + return idx; + } + } + + /** + * The objects can not be replaced. + */ + @Override + public void replace(int idx, Object obj) { + throw new UnsupportedOperationException(); + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala index 61a29c89d8df3..57de0f26a9720 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala @@ -28,7 +28,10 @@ import org.apache.spark.unsafe.types.UTF8String abstract class InternalRow extends Row { // This is only use for test - override def getString(i: Int): String = getAs[UTF8String](i).toString + override def getString(i: Int): String = { + val str = getAs[UTF8String](i) + if (str != null) str.toString else null + } // These expensive API should not be used internally. 
final override def getDecimal(i: Int): java.math.BigDecimal = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala index b61d490429e4f..b11fc245c4af9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.expressions +import org.apache.spark.sql.catalyst.util.ObjectPool import org.apache.spark.sql.types._ import org.apache.spark.unsafe.PlatformDependent import org.apache.spark.unsafe.array.ByteArrayMethods @@ -33,6 +34,8 @@ class UnsafeRowConverter(fieldTypes: Array[DataType]) { this(schema.fields.map(_.dataType)) } + def numFields: Int = fieldTypes.length + /** Re-used pointer to the unsafe row being written */ private[this] val unsafeRow = new UnsafeRow() @@ -68,8 +71,8 @@ class UnsafeRowConverter(fieldTypes: Array[DataType]) { * @param baseOffset the base offset of the destination address * @return the number of bytes written. This should be equal to `getSizeRequirement(row)`. */ - def writeRow(row: InternalRow, baseObject: Object, baseOffset: Long): Int = { - unsafeRow.pointTo(baseObject, baseOffset, writers.length, null) + def writeRow(row: InternalRow, baseObject: Object, baseOffset: Long, pool: ObjectPool): Int = { + unsafeRow.pointTo(baseObject, baseOffset, writers.length, pool) if (writers.length > 0) { // zero-out the bitset @@ -84,16 +87,16 @@ class UnsafeRowConverter(fieldTypes: Array[DataType]) { } var fieldNumber = 0 - var appendCursor: Int = fixedLengthSize + var cursor: Int = fixedLengthSize while (fieldNumber < writers.length) { if (row.isNullAt(fieldNumber)) { unsafeRow.setNullAt(fieldNumber) } else { - appendCursor += writers(fieldNumber).write(row, unsafeRow, fieldNumber, appendCursor) + cursor += writers(fieldNumber).write(row, unsafeRow, fieldNumber, cursor) } fieldNumber += 1 } - appendCursor + cursor } } @@ -108,11 +111,11 @@ private abstract class UnsafeColumnWriter { * @param source the row being converted * @param target a pointer to the converted unsafe row * @param column the column to write - * @param appendCursor the offset from the start of the unsafe row to the end of the row; + * @param cursor the offset from the start of the unsafe row to the end of the row; * used for calculating where variable-length data should be written * @return the number of variable-length bytes written */ - def write(source: InternalRow, target: UnsafeRow, column: Int, appendCursor: Int): Int + def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int /** * Return the number of bytes that are needed to write this variable-length value. 
@@ -134,8 +137,7 @@ private object UnsafeColumnWriter { case DoubleType => DoubleUnsafeColumnWriter case StringType => StringUnsafeColumnWriter case BinaryType => BinaryUnsafeColumnWriter - case t => - throw new UnsupportedOperationException(s"Do not know how to write columns of type $t") + case t => ObjectUnsafeColumnWriter } } } @@ -152,6 +154,7 @@ private object FloatUnsafeColumnWriter extends FloatUnsafeColumnWriter private object DoubleUnsafeColumnWriter extends DoubleUnsafeColumnWriter private object StringUnsafeColumnWriter extends StringUnsafeColumnWriter private object BinaryUnsafeColumnWriter extends BinaryUnsafeColumnWriter +private object ObjectUnsafeColumnWriter extends ObjectUnsafeColumnWriter private abstract class PrimitiveUnsafeColumnWriter extends UnsafeColumnWriter { // Primitives don't write to the variable-length region: @@ -159,88 +162,56 @@ private abstract class PrimitiveUnsafeColumnWriter extends UnsafeColumnWriter { } private class NullUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter { - override def write( - source: InternalRow, - target: UnsafeRow, - column: Int, - appendCursor: Int): Int = { + override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = { target.setNullAt(column) 0 } } private class BooleanUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter { - override def write( - source: InternalRow, - target: UnsafeRow, - column: Int, - appendCursor: Int): Int = { + override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = { target.setBoolean(column, source.getBoolean(column)) 0 } } private class ByteUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter { - override def write( - source: InternalRow, - target: UnsafeRow, - column: Int, - appendCursor: Int): Int = { + override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = { target.setByte(column, source.getByte(column)) 0 } } private class ShortUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter { - override def write( - source: InternalRow, - target: UnsafeRow, - column: Int, - appendCursor: Int): Int = { + override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = { target.setShort(column, source.getShort(column)) 0 } } private class IntUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter { - override def write( - source: InternalRow, - target: UnsafeRow, - column: Int, - appendCursor: Int): Int = { + override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = { target.setInt(column, source.getInt(column)) 0 } } private class LongUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter { - override def write( - source: InternalRow, - target: UnsafeRow, - column: Int, - appendCursor: Int): Int = { + override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = { target.setLong(column, source.getLong(column)) 0 } } private class FloatUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter { - override def write( - source: InternalRow, - target: UnsafeRow, - column: Int, - appendCursor: Int): Int = { + override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = { target.setFloat(column, source.getFloat(column)) 0 } } private class DoubleUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter { - override def write( - source: InternalRow, - target: UnsafeRow, - column: Int, - appendCursor: Int): 
Int = { + override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = { target.setDouble(column, source.getDouble(column)) 0 } @@ -255,12 +226,10 @@ private abstract class BytesUnsafeColumnWriter extends UnsafeColumnWriter { ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes) } - override def write( - source: InternalRow, - target: UnsafeRow, - column: Int, - appendCursor: Int): Int = { - val offset = target.getBaseOffset + appendCursor + protected[this] def isString: Boolean + + override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = { + val offset = target.getBaseOffset + cursor val bytes = getBytes(source, column) val numBytes = bytes.length if ((numBytes & 0x07) > 0) { @@ -274,19 +243,32 @@ private abstract class BytesUnsafeColumnWriter extends UnsafeColumnWriter { offset, numBytes ) - target.setLong(column, (appendCursor.toLong << 32L) | numBytes.toLong) + val flag = if (isString) 1L << (UnsafeRow.OFFSET_BITS * 2) else 0 + target.setLong(column, flag | (cursor.toLong << UnsafeRow.OFFSET_BITS) | numBytes.toLong) ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes) } } private class StringUnsafeColumnWriter private() extends BytesUnsafeColumnWriter { + protected[this] def isString: Boolean = true def getBytes(source: InternalRow, column: Int): Array[Byte] = { source.getAs[UTF8String](column).getBytes } } private class BinaryUnsafeColumnWriter private() extends BytesUnsafeColumnWriter { + protected[this] def isString: Boolean = false def getBytes(source: InternalRow, column: Int): Array[Byte] = { source.getAs[Array[Byte]](column) } } + +private class ObjectUnsafeColumnWriter private() extends UnsafeColumnWriter { + def getSize(sourceRow: InternalRow, column: Int): Int = 0 + override def write(source: InternalRow, target: UnsafeRow, column: Int, cursor: Int): Int = { + val obj = source.get(column) + val idx = target.getPool.put(obj) + target.setLong(column, - idx) + 0 + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala index 3095ccb77761b..6fafc2f86684c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeFixedWidthAggregationMapSuite.scala @@ -23,8 +23,9 @@ import scala.util.Random import org.scalatest.{BeforeAndAfterEach, Matchers} import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateProjection import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, TaskMemoryManager, MemoryAllocator} +import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager} import org.apache.spark.unsafe.types.UTF8String @@ -33,10 +34,10 @@ class UnsafeFixedWidthAggregationMapSuite with Matchers with BeforeAndAfterEach { - import UnsafeFixedWidthAggregationMap._ - private val groupKeySchema = StructType(StructField("product", StringType) :: Nil) private val aggBufferSchema = StructType(StructField("salePrice", IntegerType) :: Nil) + private def emptyProjection: Projection = + GenerateProjection.generate(Seq(Literal(0)), Seq(AttributeReference("price", IntegerType)())) private def emptyAggregationBuffer: InternalRow = InternalRow(0) private var memoryManager: 
TaskMemoryManager = null @@ -52,21 +53,11 @@ class UnsafeFixedWidthAggregationMapSuite } } - test("supported schemas") { - assert(!supportsAggregationBufferSchema(StructType(StructField("x", StringType) :: Nil))) - assert(supportsGroupKeySchema(StructType(StructField("x", StringType) :: Nil))) - - assert( - !supportsAggregationBufferSchema(StructType(StructField("x", ArrayType(IntegerType)) :: Nil))) - assert( - !supportsGroupKeySchema(StructType(StructField("x", ArrayType(IntegerType)) :: Nil))) - } - test("empty map") { val map = new UnsafeFixedWidthAggregationMap( - emptyAggregationBuffer, - aggBufferSchema, - groupKeySchema, + emptyProjection, + new UnsafeRowConverter(groupKeySchema), + new UnsafeRowConverter(aggBufferSchema), memoryManager, 1024, // initial capacity false // disable perf metrics @@ -77,9 +68,9 @@ class UnsafeFixedWidthAggregationMapSuite test("updating values for a single key") { val map = new UnsafeFixedWidthAggregationMap( - emptyAggregationBuffer, - aggBufferSchema, - groupKeySchema, + emptyProjection, + new UnsafeRowConverter(groupKeySchema), + new UnsafeRowConverter(aggBufferSchema), memoryManager, 1024, // initial capacity false // disable perf metrics @@ -103,9 +94,9 @@ class UnsafeFixedWidthAggregationMapSuite test("inserting large random keys") { val map = new UnsafeFixedWidthAggregationMap( - emptyAggregationBuffer, - aggBufferSchema, - groupKeySchema, + emptyProjection, + new UnsafeRowConverter(groupKeySchema), + new UnsafeRowConverter(aggBufferSchema), memoryManager, 128, // initial capacity false // disable perf metrics @@ -120,6 +111,36 @@ class UnsafeFixedWidthAggregationMapSuite }.toSet seenKeys.size should be (groupKeys.size) seenKeys should be (groupKeys) + + map.free() + } + + test("with decimal in the key and values") { + val groupKeySchema = StructType(StructField("price", DecimalType(10, 0)) :: Nil) + val aggBufferSchema = StructType(StructField("amount", DecimalType.Unlimited) :: Nil) + val emptyProjection = GenerateProjection.generate(Seq(Literal(Decimal(0))), + Seq(AttributeReference("price", DecimalType.Unlimited)())) + val map = new UnsafeFixedWidthAggregationMap( + emptyProjection, + new UnsafeRowConverter(groupKeySchema), + new UnsafeRowConverter(aggBufferSchema), + memoryManager, + 1, // initial capacity + false // disable perf metrics + ) + + (0 until 100).foreach { i => + val groupKey = InternalRow(Decimal(i % 10)) + val row = map.getAggregationBuffer(groupKey) + row.update(0, Decimal(i)) + } + val seenKeys: Set[Int] = map.iterator().asScala.map { entry => + entry.key.getAs[Decimal](0).toInt + }.toSet + seenKeys.size should be (10) + seenKeys should be ((0 until 10).toSet) + + map.free() } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala index c0675f4f4dff6..94c2f3242b122 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala @@ -23,10 +23,11 @@ import java.util.Arrays import org.scalatest.Matchers import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.catalyst.util.{ObjectPool, DateTimeUtils} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.PlatformDependent import org.apache.spark.unsafe.array.ByteArrayMethods +import 
org.apache.spark.unsafe.types.UTF8String class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { @@ -40,16 +41,21 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { row.setInt(2, 2) val sizeRequired: Int = converter.getSizeRequirement(row) - sizeRequired should be (8 + (3 * 8)) + assert(sizeRequired === 8 + (3 * 8)) val buffer: Array[Long] = new Array[Long](sizeRequired / 8) - val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET) - numBytesWritten should be (sizeRequired) + val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, null) + assert(numBytesWritten === sizeRequired) val unsafeRow = new UnsafeRow() unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null) - unsafeRow.getLong(0) should be (0) - unsafeRow.getLong(1) should be (1) - unsafeRow.getInt(2) should be (2) + assert(unsafeRow.getLong(0) === 0) + assert(unsafeRow.getLong(1) === 1) + assert(unsafeRow.getInt(2) === 2) + + unsafeRow.setLong(1, 3) + assert(unsafeRow.getLong(1) === 3) + unsafeRow.setInt(2, 4) + assert(unsafeRow.getInt(2) === 4) } test("basic conversion with primitive, string and binary types") { @@ -58,22 +64,67 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { val row = new SpecificMutableRow(fieldTypes) row.setLong(0, 0) - row.setString(1, "Hello") + row.update(1, UTF8String.fromString("Hello")) row.update(2, "World".getBytes) val sizeRequired: Int = converter.getSizeRequirement(row) - sizeRequired should be (8 + (8 * 3) + + assert(sizeRequired === 8 + (8 * 3) + ByteArrayMethods.roundNumberOfBytesToNearestWord("Hello".getBytes.length) + ByteArrayMethods.roundNumberOfBytesToNearestWord("World".getBytes.length)) val buffer: Array[Long] = new Array[Long](sizeRequired / 8) - val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET) - numBytesWritten should be (sizeRequired) + val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, null) + assert(numBytesWritten === sizeRequired) val unsafeRow = new UnsafeRow() - unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null) - unsafeRow.getLong(0) should be (0) - unsafeRow.getString(1) should be ("Hello") - unsafeRow.getBinary(2) should be ("World".getBytes) + val pool = new ObjectPool(10) + unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, pool) + assert(unsafeRow.getLong(0) === 0) + assert(unsafeRow.getString(1) === "Hello") + assert(unsafeRow.get(2) === "World".getBytes) + + unsafeRow.update(1, UTF8String.fromString("World")) + assert(unsafeRow.getString(1) === "World") + assert(pool.size === 0) + unsafeRow.update(1, UTF8String.fromString("Hello World")) + assert(unsafeRow.getString(1) === "Hello World") + assert(pool.size === 1) + + unsafeRow.update(2, "World".getBytes) + assert(unsafeRow.get(2) === "World".getBytes) + assert(pool.size === 1) + unsafeRow.update(2, "Hello World".getBytes) + assert(unsafeRow.get(2) === "Hello World".getBytes) + assert(pool.size === 2) + } + + test("basic conversion with primitive, decimal and array") { + val fieldTypes: Array[DataType] = Array(LongType, DecimalType(10, 0), ArrayType(StringType)) + val converter = new UnsafeRowConverter(fieldTypes) + + val row = new SpecificMutableRow(fieldTypes) + row.setLong(0, 0) + row.update(1, Decimal(1)) + row.update(2, Array(2)) + + val pool = new ObjectPool(10) + val sizeRequired: Int = 
converter.getSizeRequirement(row) + assert(sizeRequired === 8 + (8 * 3)) + val buffer: Array[Long] = new Array[Long](sizeRequired / 8) + val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, pool) + assert(numBytesWritten === sizeRequired) + assert(pool.size === 2) + + val unsafeRow = new UnsafeRow() + unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, pool) + assert(unsafeRow.getLong(0) === 0) + assert(unsafeRow.get(1) === Decimal(1)) + assert(unsafeRow.get(2) === Array(2)) + + unsafeRow.update(1, Decimal(2)) + assert(unsafeRow.get(1) === Decimal(2)) + unsafeRow.update(2, Array(3, 4)) + assert(unsafeRow.get(2) === Array(3, 4)) + assert(pool.size === 2) } test("basic conversion with primitive, string, date and timestamp types") { @@ -87,21 +138,27 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { row.update(3, DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2015-05-08 08:10:25"))) val sizeRequired: Int = converter.getSizeRequirement(row) - sizeRequired should be (8 + (8 * 4) + + assert(sizeRequired === 8 + (8 * 4) + ByteArrayMethods.roundNumberOfBytesToNearestWord("Hello".getBytes.length)) val buffer: Array[Long] = new Array[Long](sizeRequired / 8) - val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET) - numBytesWritten should be (sizeRequired) + val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET, null) + assert(numBytesWritten === sizeRequired) val unsafeRow = new UnsafeRow() unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null) - unsafeRow.getLong(0) should be (0) - unsafeRow.getString(1) should be ("Hello") + assert(unsafeRow.getLong(0) === 0) + assert(unsafeRow.getString(1) === "Hello") // Date is represented as Int in unsafeRow - DateTimeUtils.toJavaDate(unsafeRow.getInt(2)) should be (Date.valueOf("1970-01-01")) + assert(DateTimeUtils.toJavaDate(unsafeRow.getInt(2)) === Date.valueOf("1970-01-01")) // Timestamp is represented as Long in unsafeRow DateTimeUtils.toJavaTimestamp(unsafeRow.getLong(3)) should be (Timestamp.valueOf("2015-05-08 08:10:25")) + + unsafeRow.setInt(2, DateTimeUtils.fromJavaDate(Date.valueOf("2015-06-22"))) + assert(DateTimeUtils.toJavaDate(unsafeRow.getInt(2)) === Date.valueOf("2015-06-22")) + unsafeRow.setLong(3, DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2015-06-22 08:10:25"))) + DateTimeUtils.toJavaTimestamp(unsafeRow.getLong(3)) should be + (Timestamp.valueOf("2015-06-22 08:10:25")) } test("null handling") { @@ -113,7 +170,12 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { IntegerType, LongType, FloatType, - DoubleType) + DoubleType, + StringType, + BinaryType, + DecimalType.Unlimited, + ArrayType(IntegerType) + ) val converter = new UnsafeRowConverter(fieldTypes) val rowWithAllNullColumns: InternalRow = { @@ -127,8 +189,8 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { val sizeRequired: Int = converter.getSizeRequirement(rowWithAllNullColumns) val createdFromNullBuffer: Array[Long] = new Array[Long](sizeRequired / 8) val numBytesWritten = converter.writeRow( - rowWithAllNullColumns, createdFromNullBuffer, PlatformDependent.LONG_ARRAY_OFFSET) - numBytesWritten should be (sizeRequired) + rowWithAllNullColumns, createdFromNullBuffer, PlatformDependent.LONG_ARRAY_OFFSET, null) + assert(numBytesWritten === sizeRequired) val createdFromNull = new UnsafeRow() createdFromNull.pointTo( @@ -136,13 +198,17 @@ class 
UnsafeRowConverterSuite extends SparkFunSuite with Matchers { for (i <- 0 to fieldTypes.length - 1) { assert(createdFromNull.isNullAt(i)) } - createdFromNull.getBoolean(1) should be (false) - createdFromNull.getByte(2) should be (0) - createdFromNull.getShort(3) should be (0) - createdFromNull.getInt(4) should be (0) - createdFromNull.getLong(5) should be (0) + assert(createdFromNull.getBoolean(1) === false) + assert(createdFromNull.getByte(2) === 0) + assert(createdFromNull.getShort(3) === 0) + assert(createdFromNull.getInt(4) === 0) + assert(createdFromNull.getLong(5) === 0) assert(java.lang.Float.isNaN(createdFromNull.getFloat(6))) - assert(java.lang.Double.isNaN(createdFromNull.getFloat(7))) + assert(java.lang.Double.isNaN(createdFromNull.getDouble(7))) + assert(createdFromNull.getString(8) === null) + assert(createdFromNull.get(9) === null) + assert(createdFromNull.get(10) === null) + assert(createdFromNull.get(11) === null) // If we have an UnsafeRow with columns that are initially non-null and we null out those // columns, then the serialized row representation should be identical to what we would get by @@ -157,28 +223,68 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { r.setLong(5, 500) r.setFloat(6, 600) r.setDouble(7, 700) + r.update(8, UTF8String.fromString("hello")) + r.update(9, "world".getBytes) + r.update(10, Decimal(10)) + r.update(11, Array(11)) r } - val setToNullAfterCreationBuffer: Array[Long] = new Array[Long](sizeRequired / 8) + val pool = new ObjectPool(1) + val setToNullAfterCreationBuffer: Array[Long] = new Array[Long](sizeRequired / 8 + 2) converter.writeRow( - rowWithNoNullColumns, setToNullAfterCreationBuffer, PlatformDependent.LONG_ARRAY_OFFSET) + rowWithNoNullColumns, setToNullAfterCreationBuffer, PlatformDependent.LONG_ARRAY_OFFSET, pool) val setToNullAfterCreation = new UnsafeRow() setToNullAfterCreation.pointTo( - setToNullAfterCreationBuffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null) + setToNullAfterCreationBuffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, pool) - setToNullAfterCreation.isNullAt(0) should be (rowWithNoNullColumns.isNullAt(0)) - setToNullAfterCreation.getBoolean(1) should be (rowWithNoNullColumns.getBoolean(1)) - setToNullAfterCreation.getByte(2) should be (rowWithNoNullColumns.getByte(2)) - setToNullAfterCreation.getShort(3) should be (rowWithNoNullColumns.getShort(3)) - setToNullAfterCreation.getInt(4) should be (rowWithNoNullColumns.getInt(4)) - setToNullAfterCreation.getLong(5) should be (rowWithNoNullColumns.getLong(5)) - setToNullAfterCreation.getFloat(6) should be (rowWithNoNullColumns.getFloat(6)) - setToNullAfterCreation.getDouble(7) should be (rowWithNoNullColumns.getDouble(7)) + assert(setToNullAfterCreation.isNullAt(0) === rowWithNoNullColumns.isNullAt(0)) + assert(setToNullAfterCreation.getBoolean(1) === rowWithNoNullColumns.getBoolean(1)) + assert(setToNullAfterCreation.getByte(2) === rowWithNoNullColumns.getByte(2)) + assert(setToNullAfterCreation.getShort(3) === rowWithNoNullColumns.getShort(3)) + assert(setToNullAfterCreation.getInt(4) === rowWithNoNullColumns.getInt(4)) + assert(setToNullAfterCreation.getLong(5) === rowWithNoNullColumns.getLong(5)) + assert(setToNullAfterCreation.getFloat(6) === rowWithNoNullColumns.getFloat(6)) + assert(setToNullAfterCreation.getDouble(7) === rowWithNoNullColumns.getDouble(7)) + assert(setToNullAfterCreation.getString(8) === rowWithNoNullColumns.getString(8)) + assert(setToNullAfterCreation.get(9) === rowWithNoNullColumns.get(9)) + 
assert(setToNullAfterCreation.get(10) === rowWithNoNullColumns.get(10)) + assert(setToNullAfterCreation.get(11) === rowWithNoNullColumns.get(11)) for (i <- 0 to fieldTypes.length - 1) { + if (i >= 8) { + setToNullAfterCreation.update(i, null) + } setToNullAfterCreation.setNullAt(i) } - assert(Arrays.equals(createdFromNullBuffer, setToNullAfterCreationBuffer)) + // There are some garbage left in the var-length area + assert(Arrays.equals(createdFromNullBuffer, + java.util.Arrays.copyOf(setToNullAfterCreationBuffer, sizeRequired / 8))) + + setToNullAfterCreation.setNullAt(0) + setToNullAfterCreation.setBoolean(1, false) + setToNullAfterCreation.setByte(2, 20) + setToNullAfterCreation.setShort(3, 30) + setToNullAfterCreation.setInt(4, 400) + setToNullAfterCreation.setLong(5, 500) + setToNullAfterCreation.setFloat(6, 600) + setToNullAfterCreation.setDouble(7, 700) + setToNullAfterCreation.update(8, UTF8String.fromString("hello")) + setToNullAfterCreation.update(9, "world".getBytes) + setToNullAfterCreation.update(10, Decimal(10)) + setToNullAfterCreation.update(11, Array(11)) + + assert(setToNullAfterCreation.isNullAt(0) === rowWithNoNullColumns.isNullAt(0)) + assert(setToNullAfterCreation.getBoolean(1) === rowWithNoNullColumns.getBoolean(1)) + assert(setToNullAfterCreation.getByte(2) === rowWithNoNullColumns.getByte(2)) + assert(setToNullAfterCreation.getShort(3) === rowWithNoNullColumns.getShort(3)) + assert(setToNullAfterCreation.getInt(4) === rowWithNoNullColumns.getInt(4)) + assert(setToNullAfterCreation.getLong(5) === rowWithNoNullColumns.getLong(5)) + assert(setToNullAfterCreation.getFloat(6) === rowWithNoNullColumns.getFloat(6)) + assert(setToNullAfterCreation.getDouble(7) === rowWithNoNullColumns.getDouble(7)) + assert(setToNullAfterCreation.getString(8) === rowWithNoNullColumns.getString(8)) + assert(setToNullAfterCreation.get(9) === rowWithNoNullColumns.get(9)) + assert(setToNullAfterCreation.get(10) === rowWithNoNullColumns.get(10)) + assert(setToNullAfterCreation.get(11) === rowWithNoNullColumns.get(11)) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ObjectPoolSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ObjectPoolSuite.scala new file mode 100644 index 0000000000000..94764df4b9cdb --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ObjectPoolSuite.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.util + +import org.scalatest.Matchers + +import org.apache.spark.SparkFunSuite + +class ObjectPoolSuite extends SparkFunSuite with Matchers { + + test("pool") { + val pool = new ObjectPool(1) + assert(pool.put(1) === 0) + assert(pool.put("hello") === 1) + assert(pool.put(false) === 2) + + assert(pool.get(0) === 1) + assert(pool.get(1) === "hello") + assert(pool.get(2) === false) + assert(pool.size() === 3) + + pool.replace(1, "world") + assert(pool.get(1) === "world") + assert(pool.size() === 3) + } + + test("unique pool") { + val pool = new UniqueObjectPool(1) + assert(pool.put(1) === 0) + assert(pool.put("hello") === 1) + assert(pool.put(1) === 0) + assert(pool.put("hello") === 1) + + assert(pool.get(0) === 1) + assert(pool.get(1) === "hello") + assert(pool.size() === 2) + + intercept[UnsupportedOperationException] { + pool.replace(1, "world") + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala index ba2c8f53d702d..44930f82b53a0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala @@ -238,11 +238,6 @@ case class GeneratedAggregate( StructType(fields) } - val schemaSupportsUnsafe: Boolean = { - UnsafeFixedWidthAggregationMap.supportsAggregationBufferSchema(aggregationBufferSchema) && - UnsafeFixedWidthAggregationMap.supportsGroupKeySchema(groupKeySchema) - } - child.execute().mapPartitions { iter => // Builds a new custom class for holding the results of aggregation for a group. val initialValues = computeFunctions.flatMap(_.initialValues) @@ -283,12 +278,12 @@ case class GeneratedAggregate( val resultProjection = resultProjectionBuilder() Iterator(resultProjection(buffer)) - } else if (unsafeEnabled && schemaSupportsUnsafe) { + } else if (unsafeEnabled) { log.info("Using Unsafe-based aggregator") val aggregationMap = new UnsafeFixedWidthAggregationMap( - newAggregationBuffer(EmptyRow), - aggregationBufferSchema, - groupKeySchema, + newAggregationBuffer, + new UnsafeRowConverter(groupKeySchema), + new UnsafeRowConverter(aggregationBufferSchema), TaskContext.get.taskMemoryManager(), 1024 * 16, // initial capacity false // disable tracking of performance metrics @@ -323,9 +318,6 @@ case class GeneratedAggregate( } } } else { - if (unsafeEnabled) { - log.info("Not using Unsafe-based aggregator because it is not supported for this schema") - } val buffers = new java.util.HashMap[InternalRow, MutableRow]() var currentRow: InternalRow = null From 4e880cf5967c0933e1d098a1d1f7db34b23ca8f8 Mon Sep 17 00:00:00 2001 From: Rosstin Date: Mon, 29 Jun 2015 16:09:29 -0700 Subject: [PATCH 060/122] [SPARK-8661][ML] for LinearRegressionSuite.scala, changed javadoc-style comments to regular multiline comments, to make copy-pasting R code more simple for mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala, changed javadoc-style comments to regular multiline comments, to make copy-pasting R code more simple Author: Rosstin Closes #7098 from Rosstin/SPARK-8661 and squashes the following commits: 5a05dee [Rosstin] SPARK-8661 for LinearRegressionSuite.scala, changed javadoc-style comments to regular multiline comments to make it easier to copy-paste the R code. 
bb9a4b1 [Rosstin] Merge branch 'master' of github.com:apache/spark into SPARK-8660 242aedd [Rosstin] SPARK-8660, changed comment style from JavaDoc style to normal multiline comment in order to make copypaste into R easier, in file classification/LogisticRegressionSuite.scala 2cd2985 [Rosstin] Merge branch 'master' of github.com:apache/spark into SPARK-8639 21ac1e5 [Rosstin] Merge branch 'master' of github.com:apache/spark into SPARK-8639 6c18058 [Rosstin] fixed minor typos in docs/README.md and docs/api.md --- .../ml/regression/LinearRegressionSuite.scala | 192 +++++++++--------- 1 file changed, 96 insertions(+), 96 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index ad1e9da692ee2..5f39d44f37352 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -28,26 +28,26 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { @transient var dataset: DataFrame = _ @transient var datasetWithoutIntercept: DataFrame = _ - /** - * In `LinearRegressionSuite`, we will make sure that the model trained by SparkML - * is the same as the one trained by R's glmnet package. The following instruction - * describes how to reproduce the data in R. - * - * import org.apache.spark.mllib.util.LinearDataGenerator - * val data = - * sc.parallelize(LinearDataGenerator.generateLinearInput(6.3, Array(4.7, 7.2), - * Array(0.9, -1.3), Array(0.7, 1.2), 10000, 42, 0.1), 2) - * data.map(x=> x.label + ", " + x.features(0) + ", " + x.features(1)).coalesce(1) - * .saveAsTextFile("path") + /* + In `LinearRegressionSuite`, we will make sure that the model trained by SparkML + is the same as the one trained by R's glmnet package. The following instruction + describes how to reproduce the data in R. + + import org.apache.spark.mllib.util.LinearDataGenerator + val data = + sc.parallelize(LinearDataGenerator.generateLinearInput(6.3, Array(4.7, 7.2), + Array(0.9, -1.3), Array(0.7, 1.2), 10000, 42, 0.1), 2) + data.map(x=> x.label + ", " + x.features(0) + ", " + x.features(1)).coalesce(1) + .saveAsTextFile("path") */ override def beforeAll(): Unit = { super.beforeAll() dataset = sqlContext.createDataFrame( sc.parallelize(LinearDataGenerator.generateLinearInput( 6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 10000, 42, 0.1), 2)) - /** - * datasetWithoutIntercept is not needed for correctness testing but is useful for illustrating - * training model without intercept + /* + datasetWithoutIntercept is not needed for correctness testing but is useful for illustrating + training model without intercept */ datasetWithoutIntercept = sqlContext.createDataFrame( sc.parallelize(LinearDataGenerator.generateLinearInput( @@ -59,20 +59,20 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { val trainer = new LinearRegression val model = trainer.fit(dataset) - /** - * Using the following R code to load the data and train the model using glmnet package. 
- * - * library("glmnet") - * data <- read.csv("path", header=FALSE, stringsAsFactors=FALSE) - * features <- as.matrix(data.frame(as.numeric(data$V2), as.numeric(data$V3))) - * label <- as.numeric(data$V1) - * weights <- coef(glmnet(features, label, family="gaussian", alpha = 0, lambda = 0)) - * > weights - * 3 x 1 sparse Matrix of class "dgCMatrix" - * s0 - * (Intercept) 6.300528 - * as.numeric.data.V2. 4.701024 - * as.numeric.data.V3. 7.198257 + /* + Using the following R code to load the data and train the model using glmnet package. + + library("glmnet") + data <- read.csv("path", header=FALSE, stringsAsFactors=FALSE) + features <- as.matrix(data.frame(as.numeric(data$V2), as.numeric(data$V3))) + label <- as.numeric(data$V1) + weights <- coef(glmnet(features, label, family="gaussian", alpha = 0, lambda = 0)) + > weights + 3 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) 6.300528 + as.numeric.data.V2. 4.701024 + as.numeric.data.V3. 7.198257 */ val interceptR = 6.298698 val weightsR = Array(4.700706, 7.199082) @@ -94,29 +94,29 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { val model = trainer.fit(dataset) val modelWithoutIntercept = trainer.fit(datasetWithoutIntercept) - /** - * weights <- coef(glmnet(features, label, family="gaussian", alpha = 0, lambda = 0, - * intercept = FALSE)) - * > weights - * 3 x 1 sparse Matrix of class "dgCMatrix" - * s0 - * (Intercept) . - * as.numeric.data.V2. 6.995908 - * as.numeric.data.V3. 5.275131 + /* + weights <- coef(glmnet(features, label, family="gaussian", alpha = 0, lambda = 0, + intercept = FALSE)) + > weights + 3 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) . + as.numeric.data.V2. 6.995908 + as.numeric.data.V3. 5.275131 */ val weightsR = Array(6.995908, 5.275131) assert(model.intercept ~== 0 relTol 1E-3) assert(model.weights(0) ~== weightsR(0) relTol 1E-3) assert(model.weights(1) ~== weightsR(1) relTol 1E-3) - /** - * Then again with the data with no intercept: - * > weightsWithoutIntercept - * 3 x 1 sparse Matrix of class "dgCMatrix" - * s0 - * (Intercept) . - * as.numeric.data3.V2. 4.70011 - * as.numeric.data3.V3. 7.19943 + /* + Then again with the data with no intercept: + > weightsWithoutIntercept + 3 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) . + as.numeric.data3.V2. 4.70011 + as.numeric.data3.V3. 7.19943 */ val weightsWithoutInterceptR = Array(4.70011, 7.19943) @@ -129,14 +129,14 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { val trainer = (new LinearRegression).setElasticNetParam(1.0).setRegParam(0.57) val model = trainer.fit(dataset) - /** - * weights <- coef(glmnet(features, label, family="gaussian", alpha = 1.0, lambda = 0.57)) - * > weights - * 3 x 1 sparse Matrix of class "dgCMatrix" - * s0 - * (Intercept) 6.24300 - * as.numeric.data.V2. 4.024821 - * as.numeric.data.V3. 6.679841 + /* + weights <- coef(glmnet(features, label, family="gaussian", alpha = 1.0, lambda = 0.57)) + > weights + 3 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) 6.24300 + as.numeric.data.V2. 4.024821 + as.numeric.data.V3. 
6.679841 */ val interceptR = 6.24300 val weightsR = Array(4.024821, 6.679841) @@ -158,15 +158,15 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { .setFitIntercept(false) val model = trainer.fit(dataset) - /** - * weights <- coef(glmnet(features, label, family="gaussian", alpha = 1.0, lambda = 0.57, - * intercept=FALSE)) - * > weights - * 3 x 1 sparse Matrix of class "dgCMatrix" - * s0 - * (Intercept) . - * as.numeric.data.V2. 6.299752 - * as.numeric.data.V3. 4.772913 + /* + weights <- coef(glmnet(features, label, family="gaussian", alpha = 1.0, lambda = 0.57, + intercept=FALSE)) + > weights + 3 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) . + as.numeric.data.V2. 6.299752 + as.numeric.data.V3. 4.772913 */ val interceptR = 0.0 val weightsR = Array(6.299752, 4.772913) @@ -187,14 +187,14 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { val trainer = (new LinearRegression).setElasticNetParam(0.0).setRegParam(2.3) val model = trainer.fit(dataset) - /** - * weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.0, lambda = 2.3)) - * > weights - * 3 x 1 sparse Matrix of class "dgCMatrix" - * s0 - * (Intercept) 6.328062 - * as.numeric.data.V2. 3.222034 - * as.numeric.data.V3. 4.926260 + /* + weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.0, lambda = 2.3)) + > weights + 3 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) 6.328062 + as.numeric.data.V2. 3.222034 + as.numeric.data.V3. 4.926260 */ val interceptR = 5.269376 val weightsR = Array(3.736216, 5.712356) @@ -216,15 +216,15 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { .setFitIntercept(false) val model = trainer.fit(dataset) - /** - * weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.0, lambda = 2.3, - * intercept = FALSE)) - * > weights - * 3 x 1 sparse Matrix of class "dgCMatrix" - * s0 - * (Intercept) . - * as.numeric.data.V2. 5.522875 - * as.numeric.data.V3. 4.214502 + /* + weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.0, lambda = 2.3, + intercept = FALSE)) + > weights + 3 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) . + as.numeric.data.V2. 5.522875 + as.numeric.data.V3. 4.214502 */ val interceptR = 0.0 val weightsR = Array(5.522875, 4.214502) @@ -245,14 +245,14 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { val trainer = (new LinearRegression).setElasticNetParam(0.3).setRegParam(1.6) val model = trainer.fit(dataset) - /** - * weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, lambda = 1.6)) - * > weights - * 3 x 1 sparse Matrix of class "dgCMatrix" - * s0 - * (Intercept) 6.324108 - * as.numeric.data.V2. 3.168435 - * as.numeric.data.V3. 5.200403 + /* + weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, lambda = 1.6)) + > weights + 3 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) 6.324108 + as.numeric.data.V2. 3.168435 + as.numeric.data.V3. 5.200403 */ val interceptR = 5.696056 val weightsR = Array(3.670489, 6.001122) @@ -274,15 +274,15 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { .setFitIntercept(false) val model = trainer.fit(dataset) - /** - * weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, lambda = 1.6, - * intercept=FALSE)) - * > weights - * 3 x 1 sparse Matrix of class "dgCMatrix" - * s0 - * (Intercept) . - * as.numeric.dataM.V2. 5.673348 - * as.numeric.dataM.V3. 
4.322251 + /* + weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, lambda = 1.6, + intercept=FALSE)) + > weights + 3 x 1 sparse Matrix of class "dgCMatrix" + s0 + (Intercept) . + as.numeric.dataM.V2. 5.673348 + as.numeric.dataM.V3. 4.322251 */ val interceptR = 0.0 val weightsR = Array(5.673348, 4.322251) From 4b497a724a87ef24702c2df9ec6863ee57a87c1c Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Mon, 29 Jun 2015 16:26:05 -0700 Subject: [PATCH 061/122] [SPARK-8710] [SQL] Change ScalaReflection.mirror from a val to a def. jira: https://issues.apache.org/jira/browse/SPARK-8710 Author: Yin Huai Closes #7094 from yhuai/SPARK-8710 and squashes the following commits: c854baa [Yin Huai] Change ScalaReflection.mirror from a val to a def. --- .../org/apache/spark/sql/catalyst/ScalaReflection.scala | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 90698cd572de4..21b1de1ab9cb1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -28,7 +28,11 @@ import org.apache.spark.sql.types._ */ object ScalaReflection extends ScalaReflection { val universe: scala.reflect.runtime.universe.type = scala.reflect.runtime.universe - val mirror: universe.Mirror = universe.runtimeMirror(Thread.currentThread().getContextClassLoader) + // Since we are creating a runtime mirror using the class loader of the current thread, + // we need to use a def here. So, every time we call mirror, it uses the + // class loader of the current thread. + override def mirror: universe.Mirror = + universe.runtimeMirror(Thread.currentThread().getContextClassLoader) } /** @@ -39,7 +43,7 @@ trait ScalaReflection { val universe: scala.reflect.api.Universe /** The mirror used to access types in the universe */ - val mirror: universe.Mirror + def mirror: universe.Mirror import universe._ From 881662e9c93893430756320f51cef0fc6643f681 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 29 Jun 2015 16:34:50 -0700 Subject: [PATCH 062/122] [SPARK-8589] [SQL] cleanup DateTimeUtils move date time related operations into `DateTimeUtils` and rename some methods to make them clearer.
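For illustration only, a minimal Scala sketch (not part of the patch below) of the day/millisecond conversion that the renamed `millisToDays` and `daysToMillis` helpers perform. It assumes the JVM default time zone where `DateTimeUtils` uses a thread-local one, and the object and constant names are invented for the example:

```scala
import java.util.TimeZone

object DayConversionSketch {
  private val MillisPerDay = 24L * 60 * 60 * 1000
  // Assumption: the JVM default time zone stands in for the thread-local zone in DateTimeUtils.
  private val tz = TimeZone.getDefault

  // Days since epoch for a millisecond timestamp, shifted into local time first.
  def millisToDays(millisLocal: Long): Int =
    ((millisLocal + tz.getOffset(millisLocal)) / MillisPerDay).toInt

  // Inverse of millisToDays: local midnight of the given day, expressed in UTC milliseconds.
  def daysToMillis(days: Int): Long = {
    val millisUtc = days.toLong * MillisPerDay
    millisUtc - tz.getOffset(millisUtc)
  }

  def main(args: Array[String]): Unit = {
    val today = millisToDays(System.currentTimeMillis())
    println(s"days since epoch: $today, back to millis: ${daysToMillis(today)}")
  }
}
```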
Author: Wenchen Fan Closes #6980 from cloud-fan/datetime and squashes the following commits: 9373a9d [Wenchen Fan] cleanup DateTimeUtil --- .../spark/sql/catalyst/expressions/Cast.scala | 43 ++---------- .../sql/catalyst/util/DateTimeUtils.scala | 70 +++++++++++++------ .../spark/sql/hive/hiveWriterContainers.scala | 2 +- 3 files changed, 58 insertions(+), 57 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 8d66968a2fc35..d69d490ad666a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.expressions import java.math.{BigDecimal => JavaBigDecimal} import java.sql.{Date, Timestamp} -import java.text.{DateFormat, SimpleDateFormat} import org.apache.spark.Logging import org.apache.spark.sql.catalyst.analysis.TypeCheckResult @@ -122,9 +121,9 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w // UDFToString private[this] def castToString(from: DataType): Any => Any = from match { case BinaryType => buildCast[Array[Byte]](_, UTF8String.fromBytes) - case DateType => buildCast[Int](_, d => UTF8String.fromString(DateTimeUtils.toString(d))) + case DateType => buildCast[Int](_, d => UTF8String.fromString(DateTimeUtils.dateToString(d))) case TimestampType => buildCast[Long](_, - t => UTF8String.fromString(timestampToString(DateTimeUtils.toJavaTimestamp(t)))) + t => UTF8String.fromString(DateTimeUtils.timestampToString(t))) case _ => buildCast[Any](_, o => UTF8String.fromString(o.toString)) } @@ -183,7 +182,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w case ByteType => buildCast[Byte](_, b => longToTimestamp(b.toLong)) case DateType => - buildCast[Int](_, d => DateTimeUtils.toMillisSinceEpoch(d) * 10000) + buildCast[Int](_, d => DateTimeUtils.daysToMillis(d) * 10000) // TimestampWritable.decimalToTimestamp case DecimalType() => buildCast[Decimal](_, d => decimalToTimestamp(d)) @@ -216,18 +215,6 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w ts / 10000000.0 } - // Converts Timestamp to string according to Hive TimestampWritable convention - private[this] def timestampToString(ts: Timestamp): String = { - val timestampString = ts.toString - val formatted = Cast.threadLocalTimestampFormat.get.format(ts) - - if (timestampString.length > 19 && timestampString.substring(19) != ".0") { - formatted + timestampString.substring(19) - } else { - formatted - } - } - // DateConverter private[this] def castToDate(from: DataType): Any => Any = from match { case StringType => @@ -449,11 +436,11 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w case (DateType, StringType) => defineCodeGen(ctx, ev, c => s"""${ctx.stringType}.fromString( - org.apache.spark.sql.catalyst.util.DateTimeUtils.toString($c))""") - // Special handling required for timestamps in hive test cases since the toString function - // does not match the expected output. 
+ org.apache.spark.sql.catalyst.util.DateTimeUtils.dateToString($c))""") case (TimestampType, StringType) => - super.genCode(ctx, ev) + defineCodeGen(ctx, ev, c => + s"""${ctx.stringType}.fromString( + org.apache.spark.sql.catalyst.util.DateTimeUtils.timestampToString($c))""") case (_, StringType) => defineCodeGen(ctx, ev, c => s"${ctx.stringType}.fromString(String.valueOf($c))") @@ -477,19 +464,3 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w } } } - -object Cast { - // `SimpleDateFormat` is not thread-safe. - private[sql] val threadLocalTimestampFormat = new ThreadLocal[DateFormat] { - override def initialValue(): SimpleDateFormat = { - new SimpleDateFormat("yyyy-MM-dd HH:mm:ss") - } - } - - // `SimpleDateFormat` is not thread-safe. - private[sql] val threadLocalDateFormat = new ThreadLocal[DateFormat] { - override def initialValue(): SimpleDateFormat = { - new SimpleDateFormat("yyyy-MM-dd") - } - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index ff79884a44d00..640e67e2ecd76 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -18,11 +18,9 @@ package org.apache.spark.sql.catalyst.util import java.sql.{Date, Timestamp} -import java.text.SimpleDateFormat +import java.text.{DateFormat, SimpleDateFormat} import java.util.{Calendar, TimeZone} -import org.apache.spark.sql.catalyst.expressions.Cast - /** * Helper functions for converting between internal and external date and time representations. * Dates are exposed externally as java.sql.Date and are represented internally as the number of @@ -41,35 +39,53 @@ object DateTimeUtils { // Java TimeZone has no mention of thread safety. Use thread local instance to be safe. - private val LOCAL_TIMEZONE = new ThreadLocal[TimeZone] { + private val threadLocalLocalTimeZone = new ThreadLocal[TimeZone] { override protected def initialValue: TimeZone = { Calendar.getInstance.getTimeZone } } - private def javaDateToDays(d: Date): Int = { - millisToDays(d.getTime) + // `SimpleDateFormat` is not thread-safe. + private val threadLocalTimestampFormat = new ThreadLocal[DateFormat] { + override def initialValue(): SimpleDateFormat = { + new SimpleDateFormat("yyyy-MM-dd HH:mm:ss") + } + } + + // `SimpleDateFormat` is not thread-safe. 
+ private val threadLocalDateFormat = new ThreadLocal[DateFormat] { + override def initialValue(): SimpleDateFormat = { + new SimpleDateFormat("yyyy-MM-dd") + } } + // we should use the exact day as Int, for example, (year, month, day) -> day def millisToDays(millisLocal: Long): Int = { - ((millisLocal + LOCAL_TIMEZONE.get().getOffset(millisLocal)) / MILLIS_PER_DAY).toInt + ((millisLocal + threadLocalLocalTimeZone.get().getOffset(millisLocal)) / MILLIS_PER_DAY).toInt } - def toMillisSinceEpoch(days: Int): Long = { + // reverse of millisToDays + def daysToMillis(days: Int): Long = { val millisUtc = days.toLong * MILLIS_PER_DAY - millisUtc - LOCAL_TIMEZONE.get().getOffset(millisUtc) + millisUtc - threadLocalLocalTimeZone.get().getOffset(millisUtc) } - def fromJavaDate(date: Date): Int = { - javaDateToDays(date) - } + def dateToString(days: Int): String = + threadLocalDateFormat.get.format(toJavaDate(days)) - def toJavaDate(daysSinceEpoch: Int): Date = { - new Date(toMillisSinceEpoch(daysSinceEpoch)) - } + // Converts Timestamp to string according to Hive TimestampWritable convention. + def timestampToString(num100ns: Long): String = { + val ts = toJavaTimestamp(num100ns) + val timestampString = ts.toString + val formatted = threadLocalTimestampFormat.get.format(ts) - def toString(days: Int): String = Cast.threadLocalDateFormat.get.format(toJavaDate(days)) + if (timestampString.length > 19 && timestampString.substring(19) != ".0") { + formatted + timestampString.substring(19) + } else { + formatted + } + } def stringToTime(s: String): java.util.Date = { if (!s.contains('T')) { @@ -100,7 +116,21 @@ object DateTimeUtils { } /** - * Return a java.sql.Timestamp from number of 100ns since epoch + * Returns the number of days since epoch from from java.sql.Date. + */ + def fromJavaDate(date: Date): Int = { + millisToDays(date.getTime) + } + + /** + * Returns a java.sql.Date from number of days since epoch. + */ + def toJavaDate(daysSinceEpoch: Int): Date = { + new Date(daysToMillis(daysSinceEpoch)) + } + + /** + * Returns a java.sql.Timestamp from number of 100ns since epoch. */ def toJavaTimestamp(num100ns: Long): Timestamp = { // setNanos() will overwrite the millisecond part, so the milliseconds should be @@ -118,7 +148,7 @@ object DateTimeUtils { } /** - * Return the number of 100ns since epoch from java.sql.Timestamp. + * Returns the number of 100ns since epoch from java.sql.Timestamp. 
*/ def fromJavaTimestamp(t: Timestamp): Long = { if (t != null) { @@ -129,7 +159,7 @@ object DateTimeUtils { } /** - * Return the number of 100ns (hundred of nanoseconds) since epoch from Julian day + * Returns the number of 100ns (hundred of nanoseconds) since epoch from Julian day * and nanoseconds in a day */ def fromJulianDay(day: Int, nanoseconds: Long): Long = { @@ -139,7 +169,7 @@ object DateTimeUtils { } /** - * Return Julian day and nanoseconds in a day from the number of 100ns (hundred of nanoseconds) + * Returns Julian day and nanoseconds in a day from the number of 100ns (hundred of nanoseconds) */ def toJulianDay(num100ns: Long): (Int, Long) = { val seconds = num100ns / HUNDRED_NANOS_PER_SECOND + SECONDS_PER_DAY / 2 diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala index ab75b12e2a2e7..ecc78a5f8d321 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala @@ -201,7 +201,7 @@ private[spark] class SparkHiveDynamicPartitionWriterContainer( def convertToHiveRawString(col: String, value: Any): String = { val raw = String.valueOf(value) schema(col).dataType match { - case DateType => DateTimeUtils.toString(raw.toInt) + case DateType => DateTimeUtils.dateToString(raw.toInt) case _: DecimalType => BigDecimal(raw).toString() case _ => raw } From cec98525fd2b731cb78935bf7bc6c7963411744e Mon Sep 17 00:00:00 2001 From: zsxwing Date: Mon, 29 Jun 2015 17:19:05 -0700 Subject: [PATCH 063/122] [SPARK-8634] [STREAMING] [TESTS] Fix flaky test StreamingListenerSuite "receiver info reporting" As per the unit test log in https://amplab.cs.berkeley.edu/jenkins/job/SparkPullRequestBuilder/35754/ ``` 15/06/24 23:09:10.210 Thread-3495 INFO ReceiverTracker: Starting 1 receivers 15/06/24 23:09:10.270 Thread-3495 INFO SparkContext: Starting job: apply at Transformer.scala:22 ... 15/06/24 23:09:14.259 ForkJoinPool-4-worker-29 INFO StreamingListenerSuiteReceiver: Started receiver and sleeping 15/06/24 23:09:14.270 ForkJoinPool-4-worker-29 INFO StreamingListenerSuiteReceiver: Reporting error and sleeping ``` it needs at least 4 seconds to receive all receiver events in this slow machine, but `timeout` for `eventually` is only 2 seconds. This PR increases `timeout` to make this test stable. 
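For context, the fix below leans on ScalaTest's `eventually`, which keeps re-running a block until it passes or the timeout expires. A minimal, self-contained sketch of that pattern (illustrative only; the 4-second background sleep merely stands in for the slow receiver events, and the suite and variable names are invented):

```scala
import java.util.concurrent.atomic.AtomicBoolean

import org.scalatest.FunSuite
import org.scalatest.concurrent.Eventually._
import org.scalatest.time.SpanSugar._

class SlowEventSketch extends FunSuite {
  test("waits for an event that takes a few seconds to arrive") {
    val received = new AtomicBoolean(false)
    // Stand-in for the receiver: the event is only reported after ~4 seconds.
    new Thread(new Runnable {
      override def run(): Unit = { Thread.sleep(4000); received.set(true) }
    }).start()
    // Re-checks the assertion every 20 ms and only fails after 30 s,
    // giving a slow machine plenty of headroom.
    eventually(timeout(30.seconds), interval(20.millis)) {
      assert(received.get(), "event not reported yet")
    }
  }
}
```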
Author: zsxwing Closes #7017 from zsxwing/SPARK-8634 and squashes the following commits: 719cae4 [zsxwing] Fix flaky test StreamingListenerSuite "receiver info reporting" --- .../org/apache/spark/streaming/StreamingListenerSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala index 1dc8960d60528..7bc7727a9fbe4 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala @@ -116,7 +116,7 @@ class StreamingListenerSuite extends TestSuiteBase with Matchers { ssc.start() try { - eventually(timeout(2000 millis), interval(20 millis)) { + eventually(timeout(30 seconds), interval(20 millis)) { collector.startedReceiverStreamIds.size should equal (1) collector.startedReceiverStreamIds(0) should equal (0) collector.stoppedReceiverStreamIds should have size 1 From fbf75738feddebb352d5cedf503b573105d4b7a7 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Mon, 29 Jun 2015 17:20:05 -0700 Subject: [PATCH 064/122] [SPARK-7287] [SPARK-8567] [TEST] Add sc.stop to applications in SparkSubmitSuite Hopefully, this suite will not be flaky anymore. Author: Yin Huai Closes #7027 from yhuai/SPARK-8567 and squashes the following commits: c0167e2 [Yin Huai] Add sc.stop(). --- .../spark/deploy/SparkSubmitSuite.scala | 2 ++ .../regression-test-SPARK-8489/Main.scala | 1 + .../regression-test-SPARK-8489/test.jar | Bin 6811 -> 6828 bytes 3 files changed, 3 insertions(+) diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 357ed90be3f5c..2e05dec99b6bf 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -548,6 +548,7 @@ object JarCreationTest extends Logging { if (result.nonEmpty) { throw new Exception("Could not load user class from jar:\n" + result(0)) } + sc.stop() } } @@ -573,6 +574,7 @@ object SimpleApplicationTest { s"Master had $config=$masterValue but executor had $config=$executorValue") } } + sc.stop() } } diff --git a/sql/hive/src/test/resources/regression-test-SPARK-8489/Main.scala b/sql/hive/src/test/resources/regression-test-SPARK-8489/Main.scala index e1715177e3f1b..0e428ba1d7456 100644 --- a/sql/hive/src/test/resources/regression-test-SPARK-8489/Main.scala +++ b/sql/hive/src/test/resources/regression-test-SPARK-8489/Main.scala @@ -38,6 +38,7 @@ object Main { val df = hc.createDataFrame(Seq(MyCoolClass("1", "2", "3"))) df.collect() println("Regression test for SPARK-8489 success!") + sc.stop() } } diff --git a/sql/hive/src/test/resources/regression-test-SPARK-8489/test.jar b/sql/hive/src/test/resources/regression-test-SPARK-8489/test.jar index 4f59fba9eab558131b2587e51b7c2e2d54348bd1..5944aa6076a5fe7a8188c947fd6847c046614101 100644 GIT binary patch delta 1819 zcmY+Fdpy&N8^^!q&d6nHbB(cs!yGh4h#1w-VbXpXd4G^E~h8`RDoEK~QJrZcO7ZR4bX zovjUayHIFvIB4u~enzUU90>C?Xn`=q87-R_p_>v%U9o!iPEZU_oi;@Ykr^AW3Krfy zmcQC-)r+LqTpB$vO1<1}<6pm(VB^$y5NwOaqD4bQ%lmWmu_TxnX~D2i%w^e2>i zBOIKNgW?}+>8K5g4A!-qxZ$Kfc4x3WT+su)P9o8Af=%SWJ_dAhA>7`Spxwjw*L=tp z9;xt`sDXKE+gBR&q{>abkMI;lqA`#0&u&#K?nG~~2SL~xC$!ZoPG6>oxnO9b)VQi& zmifq4VngbH11qpDt`&zCx4@IG=F}QAi@I$SGpebd;+yvDk>Gqh^^Zr+q;E)Sv3G7< zko3JoJ{R~{d!D?tt*=?5 
zjdP+JG;wKn??*dme5_x3F0Y#1zXD}L1jCYS(hO=-59%g zE|DI0pw6uQ+#7g=PC($_RmLU#d7X8H6M-h`If3=9>EzsO$7|4z{f~_9nBA9YHb{J_ zH;AnB{gkgED;+oHa?^f-XYFd9ih)l75ME~Vb-nkE0*gTj1gZr zmI6P`fni3aj^2urMdMXv_OVeSclGogc*Sfixun8dB%B@e;ocEqVN^kr(Y(%Hqgnr0 zO77z(l|MsGo-D2lwGQ^jFA#3aoyo0EBDQu)UgFF?DjR*9h3jz>wN@(iL8Gzm6Oo7%tZq;#+T# zrP|J!?{#grxk!)JZ|XD1&30Q;o4Hk|coLU}8p@aHUu;=Q2vEsne)5NF#5xd;WU8r* zE_dC^XVylao(UYVedTTOfTzCrYUpCr(Ul>0JJD_+{ez$OrfbrE(7)1Nf*`%iTeuQu zr{n6W=vbkV1j4V&+uAwQTmkkVIc1x&mcQ0~+&4k=6+{DP?lvD~= z@1okIWs-QA$NIkLJuau}Bv&O@r8V-p=6b-y$Qa$k1=B<3NjP)MY5d()&j{$#xgi-p z{bzWbyOr=KvA!kSVKDOK!h8n=xwz&c%Z_pzb6s2;JRZzDU;hA=c#t8zcQa#|Olh5q zVZB)#=fxz3HO!k#_g>DLI)x-@xWi~ZZQa%#iKTs_4zJ;)p`hQ;(85_<-|9~x)$3TU zX>^MHgbm**D`=hw8Q+wvO#gNBs0(kB)gAo~Z#$v=J2ss*#38DZ%2Lr$lZ@eyB zE(}E8e5hucwbeJ=UGdL@56FhwjB1aqh^3>;5-rH8%=+aRp7a2C9d;?&ljUHw^j7gg>qW}FNHfIq(Cen2?@--XoI+Q}z44Cx!{ z6Io{?A_I~U0zrWP+Y3zzW+(sEP8@2=8Ob}60B}nI0OV;QFjd+POyQf+c)?Ua@1-(E zAhZ-17{9Ci{z~cZU;yAI`yWpG8!r-z`2pte{m#liG%7-9zbNZ|43&mfq}Y~f zle*&8MHGq2V6AJcy6P&^K~<}moi4lk+4Y&{pZ9q`@9%w|_pk4lYr-;-cEy9l6aZi_ z7{Dcu-ji+s9WJ~tQvM(i=dk&phz6cI7O3wrm-=@?O(7CUZB1+m%(^(&nn+=G#?4k@ zcsnp1qW<-275~XT&qCw)0FRSV6c-*xIeyPr5X`poae?~!foeYw7YFQx1tkVy>IQaO zwwtH`SG)w+Kt$-x1p)xk;(){bv6)!Fu{w3wJNc6$HS`0+ss1rBY_HuIC=Q;gQ_L8W zl67<90px-N<#3Q;vwL@JkXjMVR7*X0PZw3X*}E2WA#-O|hxFv`4)M*;LOaWW&#S+? zJ$z{;X<0PX104+6jo&@!OeIdW91M3+|B&{xxDBoTbsVBoWdqhVdzu#}>CU^DvuL-v z)5s~u+Ax9C>06}1F|qrziwDkHyMoWm9fc0D^cm7<9c0&*RzmrpP1}!*f@jeS=m4Zx zsNkG#-5Ki#r864K=2^_O!t)7v)T=1sG5+TsXSKy#DFt+le1lIxL#S>vO4Bc(40S+WxssX?ShFq1rFS^&nB0 z-R9E+8jai{)p2MWccs+jNzioQwVK;G@2?J~J;ad3>x+t(A7#q*Cec26dak*9sFiT_ zm(1D>Go^pgsMEUXFbV$&&K~ob$uX4mo{bIs4Jdju$gF?_8f~~e>#Tgff12j$VbK3| z1WStw;>&6g_Qgv=us;}+qdcJ&tA-<-8>9f8=5L`n&$Z(3jCkFj){oZoT^mflnZn)@ zEML+&Wz%;W7H0Dl^PR@I1ioV~axv+RENl8sbXRvj2@F* z(8ss@z%ujdw5-`2)oDz!qh7Il9nPjve1!TcGFkRus!4+E>bLs^5-7Ls4@)Lr_FMB? zHghXb025;y!^^B21ms3Z+D2=Bxrv0;XrVm+zO zQinBNVMvM;VbLA-Ejx^64A%mf$MmA(8=3rSx*QxPckXWWyyZaa(yiJBQ;gnj@2a6? z*(>`#pYTsCc)y=XccxxlJC(qpp5g3%Xk(ChE=M!3`iJEf0#&zOn(nZ^k{?V8=qhP5SjoAQkmwvS@NTwg`VPB*`Yaf+) zkzvCZzNivX%T_`qlp}HnE*TA*eV@I&;xJLDkT7lDa{Kf!{_jbA>L!78mNTze#Su$y z34WRQx4ERU)O+znKIgn|f7mv#Hi?Xj*^4a^DWy{~@+A_+<_({XmpazumOAWc$_n3? 
z8_&=Rv@s6#dVj!0-}=%8+ssT?GPm;e0~=9~3&TkQ`8e&dRg} z;LyxvOxn=kTYeQ$5_26g{vRI-jx-u>J6s4ZmT;H5F>4FcBl4oV2y7R+jV77~T zwZ+7!MXE~s7#&afYSS$g9sV+32A=P4-Dcg3t6uazTeA>|T}OhpB&~0}yZx6>p+iE9 zFu_@8g3yv9@8+|5prb}yuMBDC6~BixIWF*uE=8v*y0*5u0yXj`l8rdv`Cymlr@QLr zgnG4wxn1+?-@S|ZnEp;G#c3w$y;WrN@0^{tmrzI8#=SgKaO)7+*^nb_p8zTV{$~>g z8ew-N3kDaV#K=a-FeVVtBdahXkV+~DvHyVlkmH9K|MSfVgc1;c^4kdW%Qt_k&;^^dG1;{1+ From 5d30eae56051c563a8427f330b09ef66db0a0d21 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Mon, 29 Jun 2015 17:21:35 -0700 Subject: [PATCH 065/122] [SPARK-8437] [DOCS] Using directory path without wildcard for filename slow for large number of files with wholeTextFiles and binaryFiles Note that 'dir/*' can be more efficient in some Hadoop FS implementations that 'dir/' Author: Sean Owen Closes #7036 from srowen/SPARK-8437 and squashes the following commits: 0e813ae [Sean Owen] Note that 'dir/*' can be more efficient in some Hadoop FS implementations that 'dir/' --- core/src/main/scala/org/apache/spark/SparkContext.scala | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index b3c3bf3746e18..cb7e24c374152 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -831,6 +831,8 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * }}} * * @note Small files are preferred, large file is also allowable, but may cause bad performance. + * @note On some filesystems, `.../path/*` can be a more efficient way to read all files in a directory + * rather than `.../path/` or `.../path` * * @param minPartitions A suggestion value of the minimal splitting number for input data. */ @@ -878,9 +880,11 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * (a-hdfs-path/part-nnnnn, its content) * }}} * - * @param minPartitions A suggestion value of the minimal splitting number for input data. - * * @note Small files are preferred; very large files may cause bad performance. + * @note On some filesystems, `.../path/*` can be a more efficient way to read all files in a directory + * rather than `.../path/` or `.../path` + * + * @param minPartitions A suggestion value of the minimal splitting number for input data. */ @Experimental def binaryFiles( From d7f796da45d9a7c76ee4c29a9e0661ef76d8028a Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Mon, 29 Jun 2015 17:27:02 -0700 Subject: [PATCH 066/122] [SPARK-8410] [SPARK-8475] remove previous ivy resolution when using spark-submit This PR also includes re-ordering the order that repositories are used when resolving packages. User provided repositories will be prioritized. 
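For illustration, a hedged sketch (not the actual change, which is in the diff below) of what clearing the stale resolution report amounts to: deleting one XML file named after the dummy module descriptor in Ivy's default cache. The object and method names, and the example cache path, are assumptions for the example:

```scala
import java.io.File

object IvyResolutionCleanupSketch {
  // Deletes the resolution report left by a previous run (for example
  // <ivy cache>/org.apache.spark-spark-submit-parent-default.xml) so that stale
  // repository locations recorded in it cannot confuse the next resolution.
  def clearPreviousResolution(defaultCache: File, org: String, module: String, conf: String): Unit = {
    val previous = new File(defaultCache, s"$org-$module-$conf.xml")
    if (previous.exists()) previous.delete()
  }

  def main(args: Array[String]): Unit = {
    // The exact cache location depends on the Ivy settings in use; this path is only an example.
    val cache = new File(sys.props("user.home"), ".ivy2/cache")
    clearPreviousResolution(cache, "org.apache.spark", "spark-submit-parent", "default")
  }
}
```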
cc andrewor14 Author: Burak Yavuz Closes #7089 from brkyvz/delete-prev-ivy-resolution and squashes the following commits: a21f95a [Burak Yavuz] remove previous ivy resolution when using spark-submit --- .../org/apache/spark/deploy/SparkSubmit.scala | 37 ++++++++++++------- .../spark/deploy/SparkSubmitUtilsSuite.scala | 6 +-- 2 files changed, 26 insertions(+), 17 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index abf222757a95b..b1d6ec209d62b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -756,6 +756,20 @@ private[spark] object SparkSubmitUtils { val cr = new ChainResolver cr.setName("list") + val repositoryList = remoteRepos.getOrElse("") + // add any other remote repositories other than maven central + if (repositoryList.trim.nonEmpty) { + repositoryList.split(",").zipWithIndex.foreach { case (repo, i) => + val brr: IBiblioResolver = new IBiblioResolver + brr.setM2compatible(true) + brr.setUsepoms(true) + brr.setRoot(repo) + brr.setName(s"repo-${i + 1}") + cr.add(brr) + printStream.println(s"$repo added as a remote repository with the name: ${brr.getName}") + } + } + val localM2 = new IBiblioResolver localM2.setM2compatible(true) localM2.setRoot(m2Path.toURI.toString) @@ -786,20 +800,6 @@ private[spark] object SparkSubmitUtils { sp.setRoot("http://dl.bintray.com/spark-packages/maven") sp.setName("spark-packages") cr.add(sp) - - val repositoryList = remoteRepos.getOrElse("") - // add any other remote repositories other than maven central - if (repositoryList.trim.nonEmpty) { - repositoryList.split(",").zipWithIndex.foreach { case (repo, i) => - val brr: IBiblioResolver = new IBiblioResolver - brr.setM2compatible(true) - brr.setUsepoms(true) - brr.setRoot(repo) - brr.setName(s"repo-${i + 1}") - cr.add(brr) - printStream.println(s"$repo added as a remote repository with the name: ${brr.getName}") - } - } cr } @@ -922,6 +922,15 @@ private[spark] object SparkSubmitUtils { // A Module descriptor must be specified. Entries are dummy strings val md = getModuleDescriptor + // clear ivy resolution from previous launches. The resolution file is usually at + // ~/.ivy2/org.apache.spark-spark-submit-parent-default.xml. 
In between runs, this file + // leads to confusion with Ivy when the files can no longer be found at the repository + // declared in that file. + val mdId = md.getModuleRevisionId + val previousResolution = new File(ivySettings.getDefaultCache, + s"${mdId.getOrganisation}-${mdId.getName}-$ivyConfName.xml") + if (previousResolution.exists) previousResolution.delete + md.setDefaultConf(ivyConfName) // Add exclusion rules for Spark and Scala Library diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala index 12c40f0b7d658..c9b435a9228d3 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala @@ -77,9 +77,9 @@ class SparkSubmitUtilsSuite extends SparkFunSuite with BeforeAndAfterAll { assert(resolver2.getResolvers.size() === 7) val expected = repos.split(",").map(r => s"$r/") resolver2.getResolvers.toArray.zipWithIndex.foreach { case (resolver: AbstractResolver, i) => - if (i > 3) { - assert(resolver.getName === s"repo-${i - 3}") - assert(resolver.asInstanceOf[IBiblioResolver].getRoot === expected(i - 4)) + if (i < 3) { + assert(resolver.getName === s"repo-${i + 1}") + assert(resolver.asInstanceOf[IBiblioResolver].getRoot === expected(i)) } } } From 4a9e03fa850af9e4ee56d011671faa04fb601170 Mon Sep 17 00:00:00 2001 From: Michael Sannella x268 Date: Mon, 29 Jun 2015 17:28:28 -0700 Subject: [PATCH 067/122] [SPARK-8019] [SPARKR] Support SparkR spawning worker R processes with a command other than Rscript This is a simple change to add a new environment variable "spark.sparkr.r.command" that specifies the command that SparkR will use when creating an R engine process. If this is not specified, "Rscript" will be used by default. I did not add any documentation, since I couldn't find any place where environment variables (such as "spark.sparkr.use.daemon") are documented. I also did not add a unit test. The only test that would work generally would be one starting SparkR with sparkR.init(sparkEnvir=list(spark.sparkr.r.command="Rscript")), just using the default value. I think that this is a low-risk change.
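For illustration, a hedged sketch of how such a setting is read: `SparkConf.get` with "Rscript" as the fallback, so deployments that do not set the property keep the current behaviour. The object name and the example Rscript path are invented for the example:

```scala
import org.apache.spark.SparkConf

object RCommandSketch {
  // Falls back to "Rscript" when spark.sparkr.r.command is not set.
  def rCommand(conf: SparkConf): String =
    conf.get("spark.sparkr.r.command", "Rscript")

  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .setAppName("r-command-sketch")
      .set("spark.sparkr.r.command", "/usr/local/bin/Rscript") // example path only
    // The resolved command would then be used to launch the R worker process,
    // e.g. Seq(rCommand(conf), "--vanilla", workerScript).
    println(rCommand(conf))
  }
}
```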
Likely committers: shivaram Author: Michael Sannella x268 Closes #6557 from msannell/altR and squashes the following commits: 7eac142 [Michael Sannella x268] add spark.sparkr.r.command config parameter --- core/src/main/scala/org/apache/spark/api/r/RRDD.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala index 4dfa7325934ff..524676544d6f5 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala @@ -391,7 +391,7 @@ private[r] object RRDD { } private def createRProcess(rLibDir: String, port: Int, script: String): BufferedStreamThread = { - val rCommand = "Rscript" + val rCommand = SparkEnv.get.conf.get("spark.sparkr.r.command", "Rscript") val rOptions = "--vanilla" val rExecScript = rLibDir + "/SparkR/worker/" + script val pb = new ProcessBuilder(List(rCommand, rOptions, rExecScript)) From 4c1808be4d3aaa37a5a878892e91ca73ea405ffa Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Mon, 29 Jun 2015 18:32:31 -0700 Subject: [PATCH 068/122] Revert "[SPARK-8437] [DOCS] Using directory path without wildcard for filename slow for large number of files with wholeTextFiles and binaryFiles" This reverts commit 5d30eae56051c563a8427f330b09ef66db0a0d21. --- core/src/main/scala/org/apache/spark/SparkContext.scala | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index cb7e24c374152..b3c3bf3746e18 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -831,8 +831,6 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * }}} * * @note Small files are preferred, large file is also allowable, but may cause bad performance. - * @note On some filesystems, `.../path/*` can be a more efficient way to read all files in a directory - * rather than `.../path/` or `.../path` * * @param minPartitions A suggestion value of the minimal splitting number for input data. */ @@ -880,11 +878,9 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * (a-hdfs-path/part-nnnnn, its content) * }}} * - * @note Small files are preferred; very large files may cause bad performance. - * @note On some filesystems, `.../path/*` can be a more efficient way to read all files in a directory - * rather than `.../path/` or `.../path` - * * @param minPartitions A suggestion value of the minimal splitting number for input data. + * + * @note Small files are preferred; very large files may cause bad performance. 
*/ @Experimental def binaryFiles( From 620605a4a1123afaab2674e38251f1231dea17ce Mon Sep 17 00:00:00 2001 From: Feynman Liang Date: Mon, 29 Jun 2015 18:40:30 -0700 Subject: [PATCH 069/122] [SPARK-8456] [ML] Ngram featurizer python Python API for N-gram feature transformer Author: Feynman Liang Closes #6960 from feynmanliang/ngram-featurizer-python and squashes the following commits: f9e37c9 [Feynman Liang] Remove debugging code 4dd81f4 [Feynman Liang] Fix typo and doctest 06c79ac [Feynman Liang] Style guide 26c1175 [Feynman Liang] Add python NGram API --- python/pyspark/ml/feature.py | 71 +++++++++++++++++++++++++++++++++++- python/pyspark/ml/tests.py | 11 ++++++ 2 files changed, 81 insertions(+), 1 deletion(-) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index ddb33f427ac64..8804dace849b3 100644 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -21,7 +21,7 @@ from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaTransformer from pyspark.mllib.common import inherit_doc -__all__ = ['Binarizer', 'HashingTF', 'IDF', 'IDFModel', 'Normalizer', 'OneHotEncoder', +__all__ = ['Binarizer', 'HashingTF', 'IDF', 'IDFModel', 'NGram', 'Normalizer', 'OneHotEncoder', 'PolynomialExpansion', 'RegexTokenizer', 'StandardScaler', 'StandardScalerModel', 'StringIndexer', 'StringIndexerModel', 'Tokenizer', 'VectorAssembler', 'VectorIndexer', 'Word2Vec', 'Word2VecModel'] @@ -265,6 +265,75 @@ class IDFModel(JavaModel): """ +@inherit_doc +@ignore_unicode_prefix +class NGram(JavaTransformer, HasInputCol, HasOutputCol): + """ + A feature transformer that converts the input array of strings into an array of n-grams. Null + values in the input array are ignored. + It returns an array of n-grams where each n-gram is represented by a space-separated string of + words. + When the input is empty, an empty array is returned. + When the input array length is less than n (number of elements per n-gram), no n-grams are + returned. + + >>> df = sqlContext.createDataFrame([Row(inputTokens=["a", "b", "c", "d", "e"])]) + >>> ngram = NGram(n=2, inputCol="inputTokens", outputCol="nGrams") + >>> ngram.transform(df).head() + Row(inputTokens=[u'a', u'b', u'c', u'd', u'e'], nGrams=[u'a b', u'b c', u'c d', u'd e']) + >>> # Change n-gram length + >>> ngram.setParams(n=4).transform(df).head() + Row(inputTokens=[u'a', u'b', u'c', u'd', u'e'], nGrams=[u'a b c d', u'b c d e']) + >>> # Temporarily modify output column. + >>> ngram.transform(df, {ngram.outputCol: "output"}).head() + Row(inputTokens=[u'a', u'b', u'c', u'd', u'e'], output=[u'a b c d', u'b c d e']) + >>> ngram.transform(df).head() + Row(inputTokens=[u'a', u'b', u'c', u'd', u'e'], nGrams=[u'a b c d', u'b c d e']) + >>> # Must use keyword arguments to specify params. + >>> ngram.setParams("text") + Traceback (most recent call last): + ... + TypeError: Method setParams forces keyword arguments. 
+ """ + + # a placeholder to make it appear in the generated doc + n = Param(Params._dummy(), "n", "number of elements per n-gram (>=1)") + + @keyword_only + def __init__(self, n=2, inputCol=None, outputCol=None): + """ + __init__(self, n=2, inputCol=None, outputCol=None) + """ + super(NGram, self).__init__() + self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.NGram", self.uid) + self.n = Param(self, "n", "number of elements per n-gram (>=1)") + self._setDefault(n=2) + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + def setParams(self, n=2, inputCol=None, outputCol=None): + """ + setParams(self, n=2, inputCol=None, outputCol=None) + Sets params for this NGram. + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + def setN(self, value): + """ + Sets the value of :py:attr:`n`. + """ + self._paramMap[self.n] = value + return self + + def getN(self): + """ + Gets the value of n or its default value. + """ + return self.getOrDefault(self.n) + + @inherit_doc class Normalizer(JavaTransformer, HasInputCol, HasOutputCol): """ diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 6adbf166f34a8..c151d21fd661a 100644 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -252,6 +252,17 @@ def test_idf(self): output = idf0m.transform(dataset) self.assertIsNotNone(output.head().idf) + def test_ngram(self): + sqlContext = SQLContext(self.sc) + dataset = sqlContext.createDataFrame([ + ([["a", "b", "c", "d", "e"]])], ["input"]) + ngram0 = NGram(n=4, inputCol="input", outputCol="output") + self.assertEqual(ngram0.getN(), 4) + self.assertEqual(ngram0.getInputCol(), "input") + self.assertEqual(ngram0.getOutputCol(), "output") + transformedDF = ngram0.transform(dataset) + self.assertEquals(transformedDF.head().output, ["a b c d", "b c d e"]) + if __name__ == "__main__": unittest.main() From ecacb1e88a135c802e253793e7c863d6ca8d2408 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Mon, 29 Jun 2015 18:48:28 -0700 Subject: [PATCH 070/122] [SPARK-8715] ArrayOutOfBoundsException fixed for DataFrameStatSuite.crosstab cc yhuai Author: Burak Yavuz Closes #7100 from brkyvz/ct-flakiness-fix and squashes the following commits: abc299a [Burak Yavuz] change 'to' to until 7e96d7c [Burak Yavuz] ArrayOutOfBoundsException fixed for DataFrameStatSuite.crosstab --- .../test/scala/org/apache/spark/sql/DataFrameStatSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index 64ec1a70c47e6..765094da6bda7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -78,7 +78,7 @@ class DataFrameStatSuite extends SparkFunSuite { val rows = crosstab.collect() rows.foreach { row => val i = row.getString(0).toInt - for (col <- 1 to 9) { + for (col <- 1 until columnNames.length) { val j = columnNames(col).toInt assert(row.getLong(col) === expected.getOrElse((i, j), 0).toLong) } From 4915e9e3bffb57eac319ef2173b4a6ae4073d25e Mon Sep 17 00:00:00 2001 From: Steven She Date: Mon, 29 Jun 2015 18:50:09 -0700 Subject: [PATCH 071/122] [SPARK-8669] [SQL] Fix crash with BINARY (ENUM) fields with Parquet 1.7 Patch to fix crash with BINARY fields with ENUM original types. 
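For illustration, a simplified sketch (not the patch itself, which follows below) of the BINARY branch of the schema conversion: since Spark SQL has no enum type, a BINARY field annotated with the ENUM original type is surfaced as a string column, just like UTF8-annotated binaries. The object and method names are invented, and DECIMAL and other annotations are deliberately left out of this sketch:

```scala
import org.apache.parquet.schema.OriginalType
import org.apache.parquet.schema.OriginalType.{ENUM, UTF8}
import org.apache.spark.sql.types.{BinaryType, DataType, StringType}

object BinaryFieldTypeSketch {
  // Simplified mapping for BINARY fields; DECIMAL and other annotations are omitted here.
  def binaryFieldType(originalType: OriginalType, assumeBinaryIsString: Boolean): DataType =
    originalType match {
      case UTF8 | ENUM => StringType
      case null if assumeBinaryIsString => StringType
      case null => BinaryType
      case other => sys.error(s"annotation not handled in this sketch: $other")
    }

  def main(args: Array[String]): Unit = {
    println(binaryFieldType(ENUM, assumeBinaryIsString = false)) // StringType
    println(binaryFieldType(null, assumeBinaryIsString = false)) // BinaryType
  }
}
```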
Author: Steven She Closes #7048 from stevencanopy/SPARK-8669 and squashes the following commits: 2e72979 [Steven She] [SPARK-8669] [SQL] Fix crash with BINARY (ENUM) fields with Parquet 1.7 --- .../spark/sql/parquet/CatalystSchemaConverter.scala | 2 +- .../org/apache/spark/sql/parquet/ParquetSchemaSuite.scala | 8 ++++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala index 4fd3e93b70311..2be7c64612cd2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystSchemaConverter.scala @@ -177,7 +177,7 @@ private[parquet] class CatalystSchemaConverter( case BINARY => field.getOriginalType match { - case UTF8 => StringType + case UTF8 | ENUM => StringType case null if assumeBinaryIsString => StringType case null => BinaryType case DECIMAL => makeDecimalType() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala index d0bfcde7e032b..35d3c33f99a06 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala @@ -161,6 +161,14 @@ class ParquetSchemaInferenceSuite extends ParquetSchemaTest { """.stripMargin, binaryAsString = true) + testSchemaInference[Tuple1[String]]( + "binary enum as string", + """ + |message root { + | optional binary _1 (ENUM); + |} + """.stripMargin) + testSchemaInference[Tuple1[Seq[Int]]]( "non-nullable array - non-standard", """ From f9b6bf2f83d9dad273aa36d65d0560d35b941cc2 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Mon, 29 Jun 2015 18:50:23 -0700 Subject: [PATCH 072/122] [SPARK-7667] [MLLIB] MLlib Python API consistency check MLlib Python API consistency check Author: Yanbo Liang Closes #6856 from yanboliang/spark-7667 and squashes the following commits: 21bae35 [Yanbo Liang] remove duplicate code eb12f95 [Yanbo Liang] fix doc inherit problem 9e7ec3c [Yanbo Liang] address comments e763d32 [Yanbo Liang] MLlib Python API consistency check --- python/pyspark/mllib/feature.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py index f00bb93b7bf40..b5138773fd61b 100644 --- a/python/pyspark/mllib/feature.py +++ b/python/pyspark/mllib/feature.py @@ -111,6 +111,15 @@ class JavaVectorTransformer(JavaModelWrapper, VectorTransformer): """ def transform(self, vector): + """ + Applies transformation on a vector or an RDD[Vector]. + + Note: In Python, transform cannot currently be used within + an RDD transformation or action. + Call transform directly on the RDD instead. + + :param vector: Vector or RDD of Vector to be transformed. + """ if isinstance(vector, RDD): vector = vector.map(_convert_to_vector) else: @@ -191,7 +200,7 @@ def fit(self, dataset): Computes the mean and variance and stores as a model to be used for later scaling. - :param data: The data used to compute the mean and variance + :param dataset: The data used to compute the mean and variance to build the transformation model. 
:return: a StandardScalarModel """ @@ -346,10 +355,6 @@ def transform(self, x): vector :return: an RDD of TF-IDF vectors or a TF-IDF vector """ - if isinstance(x, RDD): - return JavaVectorTransformer.transform(self, x) - - x = _convert_to_vector(x) return JavaVectorTransformer.transform(self, x) def idf(self): From 7bbbe380c52419cd580d1c99c10131184e4ad440 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 29 Jun 2015 21:32:40 -0700 Subject: [PATCH 073/122] [SPARK-5161] Parallelize Python test execution This commit parallelizes the Python unit test execution, significantly reducing Jenkins build times. Parallelism is now configurable by passing the `-p` or `--parallelism` flags to either `dev/run-tests` or `python/run-tests` (the default parallelism is 4, but I've successfully tested with higher parallelism). To avoid flakiness, I've disabled the Spark Web UI for the Python tests, similar to what we've done for the JVM tests. Author: Josh Rosen Closes #7031 from JoshRosen/parallelize-python-tests and squashes the following commits: feb3763 [Josh Rosen] Re-enable other tests f87ea81 [Josh Rosen] Only log output from failed tests d4ded73 [Josh Rosen] Logging improvements a2717e1 [Josh Rosen] Make parallelism configurable via dev/run-tests 1bacf1b [Josh Rosen] Merge remote-tracking branch 'origin/master' into parallelize-python-tests 110cd9d [Josh Rosen] Fix universal_newlines for Python 3 cd13db8 [Josh Rosen] Also log python_implementation 9e31127 [Josh Rosen] Log Python --version output for each executable. a2b9094 [Josh Rosen] Bump up parallelism. 5552380 [Josh Rosen] Python 3 fix 866b5b9 [Josh Rosen] Fix lazy logging warnings in Prospector checks 87cb988 [Josh Rosen] Skip MLLib tests for PyPy 8309bfe [Josh Rosen] Temporarily disable parallelism to debug a failure 9129027 [Josh Rosen] Disable Spark UI in Python tests 037b686 [Josh Rosen] Temporarily disable JVM tests so we can test Python speedup in Jenkins. 
af4cef4 [Josh Rosen] Initial attempt at parallelizing Python test execution --- dev/run-tests | 2 +- dev/run-tests.py | 24 +++++++- dev/sparktestsupport/shellutils.py | 1 + python/pyspark/java_gateway.py | 2 + python/run-tests.py | 97 +++++++++++++++++++++++------- 5 files changed, 101 insertions(+), 25 deletions(-) diff --git a/dev/run-tests b/dev/run-tests index a00d9f0c27639..257d1e8d50bb4 100755 --- a/dev/run-tests +++ b/dev/run-tests @@ -20,4 +20,4 @@ FWDIR="$(cd "`dirname $0`"/..; pwd)" cd "$FWDIR" -exec python -u ./dev/run-tests.py +exec python -u ./dev/run-tests.py "$@" diff --git a/dev/run-tests.py b/dev/run-tests.py index e5c897b94d167..4596e07014733 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -19,6 +19,7 @@ from __future__ import print_function import itertools +from optparse import OptionParser import os import re import sys @@ -360,12 +361,13 @@ def run_scala_tests(build_tool, hadoop_version, test_modules): run_scala_tests_sbt(test_modules, test_profiles) -def run_python_tests(test_modules): +def run_python_tests(test_modules, parallelism): set_title_and_block("Running PySpark tests", "BLOCK_PYSPARK_UNIT_TESTS") command = [os.path.join(SPARK_HOME, "python", "run-tests")] if test_modules != [modules.root]: command.append("--modules=%s" % ','.join(m.name for m in test_modules)) + command.append("--parallelism=%i" % parallelism) run_cmd(command) @@ -379,7 +381,25 @@ def run_sparkr_tests(): print("Ignoring SparkR tests as R was not found in PATH") +def parse_opts(): + parser = OptionParser( + prog="run-tests" + ) + parser.add_option( + "-p", "--parallelism", type="int", default=4, + help="The number of suites to test in parallel (default %default)" + ) + + (opts, args) = parser.parse_args() + if args: + parser.error("Unsupported arguments: %s" % ' '.join(args)) + if opts.parallelism < 1: + parser.error("Parallelism cannot be less than 1") + return opts + + def main(): + opts = parse_opts() # Ensure the user home directory (HOME) is valid and is an absolute directory if not USER_HOME or not os.path.isabs(USER_HOME): print("[error] Cannot determine your home directory as an absolute path;", @@ -461,7 +481,7 @@ def main(): modules_with_python_tests = [m for m in test_modules if m.python_test_goals] if modules_with_python_tests: - run_python_tests(modules_with_python_tests) + run_python_tests(modules_with_python_tests, opts.parallelism) if any(m.should_run_r_tests for m in test_modules): run_sparkr_tests() diff --git a/dev/sparktestsupport/shellutils.py b/dev/sparktestsupport/shellutils.py index ad9b0cc89e4ab..12bd0bf3a4fe9 100644 --- a/dev/sparktestsupport/shellutils.py +++ b/dev/sparktestsupport/shellutils.py @@ -15,6 +15,7 @@ # limitations under the License. 
# +from __future__ import print_function import os import shutil import subprocess diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py index 3cee4ea6e3a35..90cd342a6cf7f 100644 --- a/python/pyspark/java_gateway.py +++ b/python/pyspark/java_gateway.py @@ -51,6 +51,8 @@ def launch_gateway(): on_windows = platform.system() == "Windows" script = "./bin/spark-submit.cmd" if on_windows else "./bin/spark-submit" submit_args = os.environ.get("PYSPARK_SUBMIT_ARGS", "pyspark-shell") + if os.environ.get("SPARK_TESTING"): + submit_args = "--conf spark.ui.enabled=false " + submit_args command = [os.path.join(SPARK_HOME, script)] + shlex.split(submit_args) # Start a socket that will be used by PythonGatewayServer to communicate its port to us diff --git a/python/run-tests.py b/python/run-tests.py index 7d485b500ee3a..aaa35e936a806 100755 --- a/python/run-tests.py +++ b/python/run-tests.py @@ -18,12 +18,19 @@ # from __future__ import print_function +import logging from optparse import OptionParser import os import re import subprocess import sys +import tempfile +from threading import Thread, Lock import time +if sys.version < '3': + import Queue +else: + import queue as Queue # Append `SPARK_HOME/dev` to the Python path so that we can import the sparktestsupport module @@ -43,34 +50,44 @@ def print_red(text): LOG_FILE = os.path.join(SPARK_HOME, "python/unit-tests.log") +FAILURE_REPORTING_LOCK = Lock() +LOGGER = logging.getLogger() def run_individual_python_test(test_name, pyspark_python): env = {'SPARK_TESTING': '1', 'PYSPARK_PYTHON': which(pyspark_python)} - print(" Running test: %s ..." % test_name, end='') + LOGGER.debug("Starting test(%s): %s", pyspark_python, test_name) start_time = time.time() - with open(LOG_FILE, 'a') as log_file: - retcode = subprocess.call( - [os.path.join(SPARK_HOME, "bin/pyspark"), test_name], - stderr=log_file, stdout=log_file, env=env) + per_test_output = tempfile.TemporaryFile() + retcode = subprocess.Popen( + [os.path.join(SPARK_HOME, "bin/pyspark"), test_name], + stderr=per_test_output, stdout=per_test_output, env=env).wait() duration = time.time() - start_time # Exit on the first failure. if retcode != 0: - with open(LOG_FILE, 'r') as log_file: - for line in log_file: + with FAILURE_REPORTING_LOCK: + with open(LOG_FILE, 'ab') as log_file: + per_test_output.seek(0) + log_file.writelines(per_test_output.readlines()) + per_test_output.seek(0) + for line in per_test_output: if not re.match('[0-9]+', line): print(line, end='') - print_red("\nHad test failures in %s; see logs." % test_name) - exit(-1) + per_test_output.close() + print_red("\nHad test failures in %s with %s; see logs." % (test_name, pyspark_python)) + # Here, we use os._exit() instead of sys.exit() in order to force Python to exit even if + # this code is invoked from a thread other than the main thread. 
+ os._exit(-1) else: - print("ok (%is)" % duration) + per_test_output.close() + LOGGER.info("Finished test(%s): %s (%is)", pyspark_python, test_name, duration) def get_default_python_executables(): python_execs = [x for x in ["python2.6", "python3.4", "pypy"] if which(x)] if "python2.6" not in python_execs: - print("WARNING: Not testing against `python2.6` because it could not be found; falling" - " back to `python` instead") + LOGGER.warning("Not testing against `python2.6` because it could not be found; falling" + " back to `python` instead") python_execs.insert(0, "python") return python_execs @@ -88,16 +105,31 @@ def parse_opts(): default=",".join(sorted(python_modules.keys())), help="A comma-separated list of Python modules to test (default: %default)" ) + parser.add_option( + "-p", "--parallelism", type="int", default=4, + help="The number of suites to test in parallel (default %default)" + ) + parser.add_option( + "--verbose", action="store_true", + help="Enable additional debug logging" + ) (opts, args) = parser.parse_args() if args: parser.error("Unsupported arguments: %s" % ' '.join(args)) + if opts.parallelism < 1: + parser.error("Parallelism cannot be less than 1") return opts def main(): opts = parse_opts() - print("Running PySpark tests. Output is in python/%s" % LOG_FILE) + if (opts.verbose): + log_level = logging.DEBUG + else: + log_level = logging.INFO + logging.basicConfig(stream=sys.stdout, level=log_level, format="%(message)s") + LOGGER.info("Running PySpark tests. Output is in python/%s", LOG_FILE) if os.path.exists(LOG_FILE): os.remove(LOG_FILE) python_execs = opts.python_executables.split(',') @@ -108,24 +140,45 @@ def main(): else: print("Error: unrecognized module %s" % module_name) sys.exit(-1) - print("Will test against the following Python executables: %s" % python_execs) - print("Will test the following Python modules: %s" % [x.name for x in modules_to_test]) + LOGGER.info("Will test against the following Python executables: %s", python_execs) + LOGGER.info("Will test the following Python modules: %s", [x.name for x in modules_to_test]) - start_time = time.time() + task_queue = Queue.Queue() for python_exec in python_execs: python_implementation = subprocess.check_output( [python_exec, "-c", "import platform; print(platform.python_implementation())"], universal_newlines=True).strip() - print("Testing with `%s`: " % python_exec, end='') - subprocess.call([python_exec, "--version"]) - + LOGGER.debug("%s python_implementation is %s", python_exec, python_implementation) + LOGGER.debug("%s version is: %s", python_exec, subprocess.check_output( + [python_exec, "--version"], stderr=subprocess.STDOUT, universal_newlines=True).strip()) for module in modules_to_test: if python_implementation not in module.blacklisted_python_implementations: - print("Running %s tests ..." 
% module.name) for test_goal in module.python_test_goals: - run_individual_python_test(test_goal, python_exec) + task_queue.put((python_exec, test_goal)) + + def process_queue(task_queue): + while True: + try: + (python_exec, test_goal) = task_queue.get_nowait() + except Queue.Empty: + break + try: + run_individual_python_test(test_goal, python_exec) + finally: + task_queue.task_done() + + start_time = time.time() + for _ in range(opts.parallelism): + worker = Thread(target=process_queue, args=(task_queue,)) + worker.daemon = True + worker.start() + try: + task_queue.join() + except (KeyboardInterrupt, SystemExit): + print_red("Exiting due to interrupt") + sys.exit(-1) total_duration = time.time() - start_time - print("Tests passed in %i seconds" % total_duration) + LOGGER.info("Tests passed in %i seconds", total_duration) if __name__ == "__main__": From ea775b0662b952849ac7fe2026fc3fd4714c37e3 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Mon, 29 Jun 2015 21:41:59 -0700 Subject: [PATCH 074/122] MAINTENANCE: Automated closing of pull requests. This commit exists to close the following pull requests on Github: Closes #1767 (close requested by 'andrewor14') Closes #6952 (close requested by 'andrewor14') Closes #7051 (close requested by 'andrewor14') Closes #5357 (close requested by 'marmbrus') Closes #5233 (close requested by 'andrewor14') Closes #6930 (close requested by 'JoshRosen') Closes #5502 (close requested by 'andrewor14') Closes #6778 (close requested by 'andrewor14') Closes #7006 (close requested by 'andrewor14') From f79410c49b2225b2acdc58293574860230987775 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 29 Jun 2015 22:32:43 -0700 Subject: [PATCH 075/122] [SPARK-8721][SQL] Rename ExpectsInputTypes => AutoCastInputTypes. Author: Reynold Xin Closes #7109 from rxin/auto-cast and squashes the following commits: a914cc3 [Reynold Xin] [SPARK-8721][SQL] Rename ExpectsInputTypes => AutoCastInputTypes. --- .../catalyst/analysis/HiveTypeCoercion.scala | 8 +- .../sql/catalyst/expressions/Expression.scala | 2 +- .../spark/sql/catalyst/expressions/math.scala | 118 ++++++++---------- .../spark/sql/catalyst/expressions/misc.scala | 6 +- .../sql/catalyst/expressions/predicates.scala | 6 +- .../expressions/stringOperations.scala | 10 +- 6 files changed, 71 insertions(+), 79 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 976fa57cb98d5..c3d68197d64ac 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -116,7 +116,7 @@ trait HiveTypeCoercion { IfCoercion :: Division :: PropagateTypes :: - ExpectedInputConversion :: + AddCastForAutoCastInputTypes :: Nil /** @@ -709,15 +709,15 @@ trait HiveTypeCoercion { /** * Casts types according to the expected input types for Expressions that have the trait - * `ExpectsInputTypes`. + * [[AutoCastInputTypes]]. */ - object ExpectedInputConversion extends Rule[LogicalPlan] { + object AddCastForAutoCastInputTypes extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { // Skip nodes who's children have not been resolved yet. 
case e if !e.childrenResolved => e - case e: ExpectsInputTypes if e.children.map(_.dataType) != e.expectedChildTypes => + case e: AutoCastInputTypes if e.children.map(_.dataType) != e.expectedChildTypes => val newC = (e.children, e.children.map(_.dataType), e.expectedChildTypes).zipped.map { case (child, actual, expected) => if (actual == expected) child else Cast(child, expected) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index f59db3d5dfc23..e5dc7b9b5c884 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -261,7 +261,7 @@ abstract class UnaryExpression extends Expression with trees.UnaryNode[Expressio * Expressions that require a specific `DataType` as input should implement this trait * so that the proper type conversions can be performed in the analyzer. */ -trait ExpectsInputTypes { +trait AutoCastInputTypes { self: Expression => def expectedChildTypes: Seq[DataType] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index 4b57ddd9c5768..a022f3727bd58 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -56,7 +56,7 @@ abstract class LeafMathExpression(c: Double, name: String) * @param name The short name of the function */ abstract class UnaryMathExpression(f: Double => Double, name: String) - extends UnaryExpression with Serializable with ExpectsInputTypes { + extends UnaryExpression with Serializable with AutoCastInputTypes { self: Product => override def expectedChildTypes: Seq[DataType] = Seq(DoubleType) @@ -99,7 +99,7 @@ abstract class UnaryMathExpression(f: Double => Double, name: String) * @param name The short name of the function */ abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String) - extends BinaryExpression with Serializable with ExpectsInputTypes { self: Product => + extends BinaryExpression with Serializable with AutoCastInputTypes { self: Product => override def expectedChildTypes: Seq[DataType] = Seq(DoubleType, DoubleType) @@ -211,19 +211,11 @@ case class ToRadians(child: Expression) extends UnaryMathExpression(math.toRadia } case class Bin(child: Expression) - extends UnaryExpression with Serializable with ExpectsInputTypes { - - val name: String = "BIN" - - override def foldable: Boolean = child.foldable - override def nullable: Boolean = true - override def toString: String = s"$name($child)" + extends UnaryExpression with Serializable with AutoCastInputTypes { override def expectedChildTypes: Seq[DataType] = Seq(LongType) override def dataType: DataType = StringType - def funcName: String = name.toLowerCase - override def eval(input: InternalRow): Any = { val evalE = child.eval(input) if (evalE == null) { @@ -239,61 +231,13 @@ case class Bin(child: Expression) } } -//////////////////////////////////////////////////////////////////////////////////////////////////// -//////////////////////////////////////////////////////////////////////////////////////////////////// -// Binary math functions -//////////////////////////////////////////////////////////////////////////////////////////////////// 
-//////////////////////////////////////////////////////////////////////////////////////////////////// - - -case class Atan2(left: Expression, right: Expression) - extends BinaryMathExpression(math.atan2, "ATAN2") { - - override def eval(input: InternalRow): Any = { - val evalE1 = left.eval(input) - if (evalE1 == null) { - null - } else { - val evalE2 = right.eval(input) - if (evalE2 == null) { - null - } else { - // With codegen, the values returned by -0.0 and 0.0 are different. Handled with +0.0 - val result = math.atan2(evalE1.asInstanceOf[Double] + 0.0, - evalE2.asInstanceOf[Double] + 0.0) - if (result.isNaN) null else result - } - } - } - - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.atan2($c1 + 0.0, $c2 + 0.0)") + s""" - if (Double.valueOf(${ev.primitive}).isNaN()) { - ${ev.isNull} = true; - } - """ - } -} - -case class Pow(left: Expression, right: Expression) - extends BinaryMathExpression(math.pow, "POWER") { - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.pow($c1, $c2)") + s""" - if (Double.valueOf(${ev.primitive}).isNaN()) { - ${ev.isNull} = true; - } - """ - } -} /** * If the argument is an INT or binary, hex returns the number as a STRING in hexadecimal format. - * Otherwise if the number is a STRING, - * it converts each character into its hexadecimal representation and returns the resulting STRING. - * Negative numbers would be treated as two's complement. + * Otherwise if the number is a STRING, it converts each character into its hex representation + * and returns the resulting STRING. Negative numbers would be treated as two's complement. */ -case class Hex(child: Expression) - extends UnaryExpression with Serializable { +case class Hex(child: Expression) extends UnaryExpression with Serializable { override def dataType: DataType = StringType @@ -337,7 +281,7 @@ case class Hex(child: Expression) private def doHex(bytes: Array[Byte], length: Int): UTF8String = { val value = new Array[Byte](length * 2) var i = 0 - while(i < length) { + while (i < length) { value(i * 2) = Character.toUpperCase(Character.forDigit( (bytes(i) & 0xF0) >>> 4, 16)).toByte value(i * 2 + 1) = Character.toUpperCase(Character.forDigit( @@ -362,6 +306,54 @@ case class Hex(child: Expression) } } + +//////////////////////////////////////////////////////////////////////////////////////////////////// +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Binary math functions +//////////////////////////////////////////////////////////////////////////////////////////////////// +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +case class Atan2(left: Expression, right: Expression) + extends BinaryMathExpression(math.atan2, "ATAN2") { + + override def eval(input: InternalRow): Any = { + val evalE1 = left.eval(input) + if (evalE1 == null) { + null + } else { + val evalE2 = right.eval(input) + if (evalE2 == null) { + null + } else { + // With codegen, the values returned by -0.0 and 0.0 are different. 
Handled with +0.0 + val result = math.atan2(evalE1.asInstanceOf[Double] + 0.0, + evalE2.asInstanceOf[Double] + 0.0) + if (result.isNaN) null else result + } + } + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.atan2($c1 + 0.0, $c2 + 0.0)") + s""" + if (Double.valueOf(${ev.primitive}).isNaN()) { + ${ev.isNull} = true; + } + """ + } +} + +case class Pow(left: Expression, right: Expression) + extends BinaryMathExpression(math.pow, "POWER") { + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.pow($c1, $c2)") + s""" + if (Double.valueOf(${ev.primitive}).isNaN()) { + ${ev.isNull} = true; + } + """ + } +} + case class Hypot(left: Expression, right: Expression) extends BinaryMathExpression(math.hypot, "HYPOT") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index 9a39165a1ff05..27805bff293f4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -31,7 +31,7 @@ import org.apache.spark.unsafe.types.UTF8String * For input of type [[BinaryType]] */ case class Md5(child: Expression) - extends UnaryExpression with ExpectsInputTypes { + extends UnaryExpression with AutoCastInputTypes { override def dataType: DataType = StringType @@ -61,7 +61,7 @@ case class Md5(child: Expression) * the hash length is not one of the permitted values, the return value is NULL. */ case class Sha2(left: Expression, right: Expression) - extends BinaryExpression with Serializable with ExpectsInputTypes { + extends BinaryExpression with Serializable with AutoCastInputTypes { override def dataType: DataType = StringType @@ -146,7 +146,7 @@ case class Sha2(left: Expression, right: Expression) * A function that calculates a sha1 hash value and returns it as a hex string * For input of type [[BinaryType]] or [[StringType]] */ -case class Sha1(child: Expression) extends UnaryExpression with ExpectsInputTypes { +case class Sha1(child: Expression) extends UnaryExpression with AutoCastInputTypes { override def dataType: DataType = StringType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 3a12d03ba6bb9..386cf6a8df6df 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -70,7 +70,7 @@ trait PredicateHelper { } -case class Not(child: Expression) extends UnaryExpression with Predicate with ExpectsInputTypes { +case class Not(child: Expression) extends UnaryExpression with Predicate with AutoCastInputTypes { override def foldable: Boolean = child.foldable override def nullable: Boolean = child.nullable override def toString: String = s"NOT $child" @@ -123,7 +123,7 @@ case class InSet(value: Expression, hset: Set[Any]) } case class And(left: Expression, right: Expression) - extends BinaryExpression with Predicate with ExpectsInputTypes { + extends BinaryExpression with Predicate with AutoCastInputTypes { override def expectedChildTypes: Seq[DataType] = Seq(BooleanType, BooleanType) @@ -172,7 +172,7 @@ case class And(left: 
Expression, right: Expression) } case class Or(left: Expression, right: Expression) - extends BinaryExpression with Predicate with ExpectsInputTypes { + extends BinaryExpression with Predicate with AutoCastInputTypes { override def expectedChildTypes: Seq[DataType] = Seq(BooleanType, BooleanType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index a6225fdafedde..ce184e4f32f18 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String -trait StringRegexExpression extends ExpectsInputTypes { +trait StringRegexExpression extends AutoCastInputTypes { self: BinaryExpression => def escape(v: String): String @@ -111,7 +111,7 @@ case class RLike(left: Expression, right: Expression) override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).find(0) } -trait CaseConversionExpression extends ExpectsInputTypes { +trait CaseConversionExpression extends AutoCastInputTypes { self: UnaryExpression => def convert(v: UTF8String): UTF8String @@ -158,7 +158,7 @@ case class Lower(child: Expression) extends UnaryExpression with CaseConversionE } /** A base trait for functions that compare two strings, returning a boolean. */ -trait StringComparison extends ExpectsInputTypes { +trait StringComparison extends AutoCastInputTypes { self: BinaryExpression => def compare(l: UTF8String, r: UTF8String): Boolean @@ -221,7 +221,7 @@ case class EndsWith(left: Expression, right: Expression) * Defined for String and Binary types. */ case class Substring(str: Expression, pos: Expression, len: Expression) - extends Expression with ExpectsInputTypes { + extends Expression with AutoCastInputTypes { def this(str: Expression, pos: Expression) = { this(str, pos, Literal(Integer.MAX_VALUE)) @@ -295,7 +295,7 @@ case class Substring(str: Expression, pos: Expression, len: Expression) /** * A function that return the length of the given string expression. */ -case class StringLength(child: Expression) extends UnaryExpression with ExpectsInputTypes { +case class StringLength(child: Expression) extends UnaryExpression with AutoCastInputTypes { override def dataType: DataType = IntegerType override def expectedChildTypes: Seq[DataType] = Seq(StringType) From e6c3f7462b3fde220ec0084b52388dd4dabb75b9 Mon Sep 17 00:00:00 2001 From: Yadong Qi Date: Mon, 29 Jun 2015 22:34:38 -0700 Subject: [PATCH 076/122] [SPARK-8650] [SQL] Use the user-specified app name priority in SparkSQLCLIDriver or HiveThriftServer2 When run `./bin/spark-sql --name query1.sql` [Before] ![before](https://cloud.githubusercontent.com/assets/1400819/8370336/fa20b75a-1bf8-11e5-9171-040049a53240.png) [After] ![after](https://cloud.githubusercontent.com/assets/1400819/8370189/dcc35cb4-1bf6-11e5-8796-a0694140bffb.png) Author: Yadong Qi Closes #7030 from watermen/SPARK-8650 and squashes the following commits: 51b5134 [Yadong Qi] Improve code and add comment. e3d7647 [Yadong Qi] use spark.app.name priority. 
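The precedence rule here is small: keep spark.app.name only when the user actually set it (that is, when it is not merely the CLI driver's class name), and otherwise fall back to the generated SparkSQL::<host> name. A minimal standalone Scala sketch of that rule follows; the object and helper names are made up for illustration, and plain strings stand in for SparkConf and Utils.localHostName().

object AppNameSketch {
  // Prefer a user-supplied spark.app.name unless it is merely the CLI driver's
  // default class name; otherwise fall back to SparkSQL::<host>.
  def chooseAppName(conf: Map[String, String], cliDriverClassName: String, host: String): String =
    conf.get("spark.app.name")
      .filterNot(_ == cliDriverClassName)
      .getOrElse(s"SparkSQL::$host")

  def main(args: Array[String]): Unit = {
    // e.g. launched as: ./bin/spark-sql --name query1.sql
    println(chooseAppName(Map("spark.app.name" -> "query1.sql"), "SparkSQLCLIDriver", "host1"))
    // no --name given: falls back to the generated name
    println(chooseAppName(Map.empty, "SparkSQLCLIDriver", "host1"))
  }
}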
--- .../apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala index 79eda1f5123bf..1d41c46131828 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala @@ -38,9 +38,14 @@ private[hive] object SparkSQLEnv extends Logging { val sparkConf = new SparkConf(loadDefaults = true) val maybeSerializer = sparkConf.getOption("spark.serializer") val maybeKryoReferenceTracking = sparkConf.getOption("spark.kryo.referenceTracking") + // If user doesn't specify the appName, we want to get [SparkSQL::localHostName] instead of + // the default appName [SparkSQLCLIDriver] in cli or beeline. + val maybeAppName = sparkConf + .getOption("spark.app.name") + .filterNot(_ == classOf[SparkSQLCLIDriver].getName) sparkConf - .setAppName(s"SparkSQL::${Utils.localHostName()}") + .setAppName(maybeAppName.getOrElse(s"SparkSQL::${Utils.localHostName()}")) .set( "spark.serializer", maybeSerializer.getOrElse("org.apache.spark.serializer.KryoSerializer")) From 6c5a6db4d53d6db8aa3464ea6713cf0d3a3bdfb5 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 29 Jun 2015 23:08:51 -0700 Subject: [PATCH 077/122] [SPARK-5161] [HOTFIX] Fix bug in Python test failure reporting This patch fixes a bug introduced in #7031 which can cause Jenkins to incorrectly report a build with failed Python tests as passing if an error occurred while printing the test failure message. Author: Josh Rosen Closes #7112 from JoshRosen/python-tests-hotfix and squashes the following commits: c3f2961 [Josh Rosen] Hotfix for bug in Python test failure reporting --- python/run-tests.py | 35 +++++++++++++++++++++++------------ 1 file changed, 23 insertions(+), 12 deletions(-) diff --git a/python/run-tests.py b/python/run-tests.py index aaa35e936a806..b7737650daa54 100755 --- a/python/run-tests.py +++ b/python/run-tests.py @@ -58,22 +58,33 @@ def run_individual_python_test(test_name, pyspark_python): env = {'SPARK_TESTING': '1', 'PYSPARK_PYTHON': which(pyspark_python)} LOGGER.debug("Starting test(%s): %s", pyspark_python, test_name) start_time = time.time() - per_test_output = tempfile.TemporaryFile() - retcode = subprocess.Popen( - [os.path.join(SPARK_HOME, "bin/pyspark"), test_name], - stderr=per_test_output, stdout=per_test_output, env=env).wait() + try: + per_test_output = tempfile.TemporaryFile() + retcode = subprocess.Popen( + [os.path.join(SPARK_HOME, "bin/pyspark"), test_name], + stderr=per_test_output, stdout=per_test_output, env=env).wait() + except: + LOGGER.exception("Got exception while running %s with %s", test_name, pyspark_python) + # Here, we use os._exit() instead of sys.exit() in order to force Python to exit even if + # this code is invoked from a thread other than the main thread. + os._exit(1) duration = time.time() - start_time # Exit on the first failure. 
if retcode != 0: - with FAILURE_REPORTING_LOCK: - with open(LOG_FILE, 'ab') as log_file: + try: + with FAILURE_REPORTING_LOCK: + with open(LOG_FILE, 'ab') as log_file: + per_test_output.seek(0) + log_file.writelines(per_test_output) per_test_output.seek(0) - log_file.writelines(per_test_output.readlines()) - per_test_output.seek(0) - for line in per_test_output: - if not re.match('[0-9]+', line): - print(line, end='') - per_test_output.close() + for line in per_test_output: + decoded_line = line.decode() + if not re.match('[0-9]+', decoded_line): + print(decoded_line, end='') + per_test_output.close() + except: + LOGGER.exception("Got an exception while trying to print failed test output") + finally: print_red("\nHad test failures in %s with %s; see logs." % (test_name, pyspark_python)) # Here, we use os._exit() instead of sys.exit() in order to force Python to exit even if # this code is invoked from a thread other than the main thread. From 12671dd5e468beedc2681ff2bdf95fba81f8f29c Mon Sep 17 00:00:00 2001 From: zsxwing Date: Mon, 29 Jun 2015 23:44:11 -0700 Subject: [PATCH 078/122] [SPARK-8434][SQL]Add a "pretty" parameter to the "show" method to display long strings Sometimes the user may want to show the complete content of cells. Now `sql("set -v").show()` displays: ![screen shot 2015-06-18 at 4 34 51 pm](https://cloud.githubusercontent.com/assets/1000778/8227339/14d3c5ea-15d9-11e5-99b9-f00b7e93beef.png) The user needs to use something like `sql("set -v").collect().foreach(r => r.toSeq.mkString("\t"))` to show the complete content. This PR adds a `pretty` parameter to show. If `pretty` is false, `show` won't truncate strings or align cells right. ![screen shot 2015-06-18 at 4 21 44 pm](https://cloud.githubusercontent.com/assets/1000778/8227407/b6f8dcac-15d9-11e5-8219-8079280d76fc.png) Author: zsxwing Closes #6877 from zsxwing/show and squashes the following commits: 22e28e9 [zsxwing] pretty -> truncate e582628 [zsxwing] Add pretty parameter to the show method in R a3cd55b [zsxwing] Fix calling showString in R 923cee4 [zsxwing] Add a "pretty" parameter to show to display long strings --- R/pkg/R/DataFrame.R | 4 +- python/pyspark/sql/dataframe.py | 7 ++- .../org/apache/spark/sql/DataFrame.scala | 55 ++++++++++++++++--- .../org/apache/spark/sql/DataFrameSuite.scala | 21 +++++++ 4 files changed, 76 insertions(+), 11 deletions(-) diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 6feabf4189c2d..60702824acb46 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -169,8 +169,8 @@ setMethod("isLocal", #'} setMethod("showDF", signature(x = "DataFrame"), - function(x, numRows = 20) { - s <- callJMethod(x@sdf, "showString", numToInt(numRows)) + function(x, numRows = 20, truncate = TRUE) { + s <- callJMethod(x@sdf, "showString", numToInt(numRows), truncate) cat(s) }) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 152b87351db31..4b9efa0a210fb 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -247,9 +247,12 @@ def isLocal(self): return self._jdf.isLocal() @since(1.3) - def show(self, n=20): + def show(self, n=20, truncate=True): """Prints the first ``n`` rows to the console. + :param n: Number of rows to show. + :param truncate: Whether truncate long strings and align cells right. 
+ >>> df DataFrame[age: int, name: string] >>> df.show() @@ -260,7 +263,7 @@ def show(self, n=20): | 5| Bob| +---+-----+ """ - print(self._jdf.showString(n)) + print(self._jdf.showString(n, truncate)) def __repr__(self): return "DataFrame[%s]" % (", ".join("%s: %s" % c for c in self.dtypes)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 986e59133919f..8fe1f7e34cb5e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -169,8 +169,9 @@ class DataFrame private[sql]( /** * Internal API for Python * @param _numRows Number of rows to show + * @param truncate Whether truncate long strings and align cells right */ - private[sql] def showString(_numRows: Int): String = { + private[sql] def showString(_numRows: Int, truncate: Boolean = true): String = { val numRows = _numRows.max(0) val sb = new StringBuilder val takeResult = take(numRows + 1) @@ -188,7 +189,7 @@ class DataFrame private[sql]( case seq: Seq[_] => seq.mkString("[", ", ", "]") case _ => cell.toString } - if (str.length > 20) str.substring(0, 17) + "..." else str + if (truncate && str.length > 20) str.substring(0, 17) + "..." else str }: Seq[String] } @@ -207,7 +208,11 @@ class DataFrame private[sql]( // column names rows.head.zipWithIndex.map { case (cell, i) => - StringUtils.leftPad(cell, colWidths(i)) + if (truncate) { + StringUtils.leftPad(cell, colWidths(i)) + } else { + StringUtils.rightPad(cell, colWidths(i)) + } }.addString(sb, "|", "|", "|\n") sb.append(sep) @@ -215,7 +220,11 @@ class DataFrame private[sql]( // data rows.tail.map { _.zipWithIndex.map { case (cell, i) => - StringUtils.leftPad(cell.toString, colWidths(i)) + if (truncate) { + StringUtils.leftPad(cell.toString, colWidths(i)) + } else { + StringUtils.rightPad(cell.toString, colWidths(i)) + } }.addString(sb, "|", "|", "|\n") } @@ -331,7 +340,8 @@ class DataFrame private[sql]( def isLocal: Boolean = logicalPlan.isInstanceOf[LocalRelation] /** - * Displays the [[DataFrame]] in a tabular form. For example: + * Displays the [[DataFrame]] in a tabular form. Strings more than 20 characters will be + * truncated, and all cells will be aligned right. For example: * {{{ * year month AVG('Adj Close) MAX('Adj Close) * 1980 12 0.503218 0.595103 @@ -345,15 +355,46 @@ class DataFrame private[sql]( * @group action * @since 1.3.0 */ - def show(numRows: Int): Unit = println(showString(numRows)) + def show(numRows: Int): Unit = show(numRows, true) /** - * Displays the top 20 rows of [[DataFrame]] in a tabular form. + * Displays the top 20 rows of [[DataFrame]] in a tabular form. Strings more than 20 characters + * will be truncated, and all cells will be aligned right. * @group action * @since 1.3.0 */ def show(): Unit = show(20) + /** + * Displays the top 20 rows of [[DataFrame]] in a tabular form. + * + * @param truncate Whether truncate long strings. If true, strings more than 20 characters will + * be truncated and all cells will be aligned right + * + * @group action + * @since 1.5.0 + */ + def show(truncate: Boolean): Unit = show(20, truncate) + + /** + * Displays the [[DataFrame]] in a tabular form. 
For example: + * {{{ + * year month AVG('Adj Close) MAX('Adj Close) + * 1980 12 0.503218 0.595103 + * 1981 01 0.523289 0.570307 + * 1982 02 0.436504 0.475256 + * 1983 03 0.410516 0.442194 + * 1984 04 0.450090 0.483521 + * }}} + * @param numRows Number of rows to show + * @param truncate Whether truncate long strings. If true, strings more than 20 characters will + * be truncated and all cells will be aligned right + * + * @group action + * @since 1.5.0 + */ + def show(numRows: Int, truncate: Boolean): Unit = println(showString(numRows, truncate)) + /** * Returns a [[DataFrameNaFunctions]] for working with missing data. * {{{ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index d06b9c5785527..50d324c0686fa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -492,6 +492,27 @@ class DataFrameSuite extends QueryTest { testData.select($"*").show(1000) } + test("showString: truncate = [true, false]") { + val longString = Array.fill(21)("1").mkString + val df = ctx.sparkContext.parallelize(Seq("1", longString)).toDF() + val expectedAnswerForFalse = """+---------------------+ + ||_1 | + |+---------------------+ + ||1 | + ||111111111111111111111| + |+---------------------+ + |""".stripMargin + assert(df.showString(10, false) === expectedAnswerForFalse) + val expectedAnswerForTrue = """+--------------------+ + || _1| + |+--------------------+ + || 1| + ||11111111111111111...| + |+--------------------+ + |""".stripMargin + assert(df.showString(10, true) === expectedAnswerForTrue) + } + test("showString(negative)") { val expectedAnswer = """+---+-----+ ||key|value| From 5452457410ffe881773f2f2cdcdc752467b19720 Mon Sep 17 00:00:00 2001 From: Shuo Xiang Date: Mon, 29 Jun 2015 23:50:34 -0700 Subject: [PATCH 079/122] [SPARK-8551] [ML] Elastic net python code example Author: Shuo Xiang Closes #6946 from coderxiang/en-java-code-example and squashes the following commits: 7a4bdf8 [Shuo Xiang] address comments cddb02b [Shuo Xiang] add elastic net python example code f4fa534 [Shuo Xiang] Merge remote-tracking branch 'upstream/master' 6ad4865 [Shuo Xiang] Merge remote-tracking branch 'upstream/master' 180b496 [Shuo Xiang] Merge remote-tracking branch 'upstream/master' aa0717d [Shuo Xiang] Merge remote-tracking branch 'upstream/master' 5f109b4 [Shuo Xiang] Merge remote-tracking branch 'upstream/master' c5c5bfe [Shuo Xiang] Merge remote-tracking branch 'upstream/master' 98804c9 [Shuo Xiang] fix bug in topBykey and update test --- .../src/main/python/ml/logistic_regression.py | 67 +++++++++++++++++++ 1 file changed, 67 insertions(+) create mode 100644 examples/src/main/python/ml/logistic_regression.py diff --git a/examples/src/main/python/ml/logistic_regression.py b/examples/src/main/python/ml/logistic_regression.py new file mode 100644 index 0000000000000..55afe1b207fe0 --- /dev/null +++ b/examples/src/main/python/ml/logistic_regression.py @@ -0,0 +1,67 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import print_function + +import sys + +from pyspark import SparkContext +from pyspark.ml.classification import LogisticRegression +from pyspark.mllib.evaluation import MulticlassMetrics +from pyspark.ml.feature import StringIndexer +from pyspark.mllib.util import MLUtils +from pyspark.sql import SQLContext + +""" +A simple example demonstrating a logistic regression with elastic net regularization Pipeline. +Run with: + bin/spark-submit examples/src/main/python/ml/logistic_regression.py +""" + +if __name__ == "__main__": + + if len(sys.argv) > 1: + print("Usage: logistic_regression", file=sys.stderr) + exit(-1) + + sc = SparkContext(appName="PythonLogisticRegressionExample") + sqlContext = SQLContext(sc) + + # Load and parse the data file into a dataframe. + df = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF() + + # Map labels into an indexed column of labels in [0, numLabels) + stringIndexer = StringIndexer(inputCol="label", outputCol="indexedLabel") + si_model = stringIndexer.fit(df) + td = si_model.transform(df) + [training, test] = td.randomSplit([0.7, 0.3]) + + lr = LogisticRegression(maxIter=100, regParam=0.3).setLabelCol("indexedLabel") + lr.setElasticNetParam(0.8) + + # Fit the model + lrModel = lr.fit(training) + + predictionAndLabels = lrModel.transform(test).select("prediction", "indexedLabel") \ + .map(lambda x: (x.prediction, x.indexedLabel)) + + metrics = MulticlassMetrics(predictionAndLabels) + print("weighted f-measure %.3f" % metrics.weightedFMeasure()) + print("precision %s" % metrics.precision()) + print("recall %s" % metrics.recall()) + + sc.stop() From 2ed0c0ac4686ea779f98713978e37b97094edc1c Mon Sep 17 00:00:00 2001 From: Tim Ellison Date: Tue, 30 Jun 2015 13:49:52 +0100 Subject: [PATCH 080/122] [SPARK-7756] [CORE] More robust SSL options processing. Subset the enabled algorithms in an SSLOptions to the elements that are supported by the protocol provider. Update the list of ciphers in the sample config to include modern algorithms, and specify both Oracle and IBM names. In practice the user would either specify their own chosen cipher suites, or specify none, and delegate the decision to the provider. Author: Tim Ellison Closes #7043 from tellison/SSLEnhancements and squashes the following commits: 034efa5 [Tim Ellison] Ensure Java imports are grouped and ordered by package. 3797f8b [Tim Ellison] Remove unnecessary use of Option to improve clarity, and fix import style ordering. 4b5c89f [Tim Ellison] More robust SSL options processing. 
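The subsetting step amounts to asking the JVM which cipher suites the provider for the configured protocol actually supports, then intersecting that set with the user's requested algorithms. A self-contained Scala sketch of that step using only standard JSSE APIs follows; the object and helper names are illustrative and not part of the patch.

import javax.net.ssl.SSLContext

object CipherSubsetSketch {
  // Keep only the requested cipher suites that the current security provider supports.
  def supportedSubset(requested: Set[String], protocol: String): Set[String] = {
    val context = SSLContext.getInstance(protocol)
    // Keys, trust managers and the RNG do not affect which suites are supported,
    // so initializing with nulls (provider defaults) is sufficient here.
    context.init(null, null, null)
    val provided = context.getServerSocketFactory.getSupportedCipherSuites.toSet
    requested.intersect(provided)
  }

  def main(args: Array[String]): Unit = {
    // The second suite name is deliberately bogus and will be discarded.
    val requested = Set("TLS_RSA_WITH_AES_256_CBC_SHA256", "TLS_NOT_A_REAL_SUITE")
    println(supportedSubset(requested, "TLSv1.2"))
  }
}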
--- .../scala/org/apache/spark/SSLOptions.scala | 43 ++++++++++++++++--- .../org/apache/spark/SSLOptionsSuite.scala | 20 ++++++--- .../org/apache/spark/SSLSampleConfigs.scala | 24 ++++++++--- .../apache/spark/SecurityManagerSuite.scala | 21 ++++++--- 4 files changed, 85 insertions(+), 23 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SSLOptions.scala b/core/src/main/scala/org/apache/spark/SSLOptions.scala index 2cdc167f85af0..32df42d57dbd6 100644 --- a/core/src/main/scala/org/apache/spark/SSLOptions.scala +++ b/core/src/main/scala/org/apache/spark/SSLOptions.scala @@ -17,7 +17,9 @@ package org.apache.spark -import java.io.File +import java.io.{File, FileInputStream} +import java.security.{KeyStore, NoSuchAlgorithmException} +import javax.net.ssl.{KeyManager, KeyManagerFactory, SSLContext, TrustManager, TrustManagerFactory} import com.typesafe.config.{Config, ConfigFactory, ConfigValueFactory} import org.eclipse.jetty.util.ssl.SslContextFactory @@ -38,7 +40,7 @@ import org.eclipse.jetty.util.ssl.SslContextFactory * @param trustStore a path to the trust-store file * @param trustStorePassword a password to access the trust-store file * @param protocol SSL protocol (remember that SSLv3 was compromised) supported by Java - * @param enabledAlgorithms a set of encryption algorithms to use + * @param enabledAlgorithms a set of encryption algorithms that may be used */ private[spark] case class SSLOptions( enabled: Boolean = false, @@ -48,7 +50,8 @@ private[spark] case class SSLOptions( trustStore: Option[File] = None, trustStorePassword: Option[String] = None, protocol: Option[String] = None, - enabledAlgorithms: Set[String] = Set.empty) { + enabledAlgorithms: Set[String] = Set.empty) + extends Logging { /** * Creates a Jetty SSL context factory according to the SSL settings represented by this object. @@ -63,7 +66,7 @@ private[spark] case class SSLOptions( trustStorePassword.foreach(sslContextFactory.setTrustStorePassword) keyPassword.foreach(sslContextFactory.setKeyManagerPassword) protocol.foreach(sslContextFactory.setProtocol) - sslContextFactory.setIncludeCipherSuites(enabledAlgorithms.toSeq: _*) + sslContextFactory.setIncludeCipherSuites(supportedAlgorithms.toSeq: _*) Some(sslContextFactory) } else { @@ -94,7 +97,7 @@ private[spark] case class SSLOptions( .withValue("akka.remote.netty.tcp.security.protocol", ConfigValueFactory.fromAnyRef(protocol.getOrElse(""))) .withValue("akka.remote.netty.tcp.security.enabled-algorithms", - ConfigValueFactory.fromIterable(enabledAlgorithms.toSeq)) + ConfigValueFactory.fromIterable(supportedAlgorithms.toSeq)) .withValue("akka.remote.netty.tcp.enable-ssl", ConfigValueFactory.fromAnyRef(true))) } else { @@ -102,6 +105,36 @@ private[spark] case class SSLOptions( } } + /* + * The supportedAlgorithms set is a subset of the enabledAlgorithms that + * are supported by the current Java security provider for this protocol. + */ + private val supportedAlgorithms: Set[String] = { + var context: SSLContext = null + try { + context = SSLContext.getInstance(protocol.orNull) + /* The set of supported algorithms does not depend upon the keys, trust, or + rng, although they will influence which algorithms are eventually used. 
*/ + context.init(null, null, null) + } catch { + case npe: NullPointerException => + logDebug("No SSL protocol specified") + context = SSLContext.getDefault + case nsa: NoSuchAlgorithmException => + logDebug(s"No support for requested SSL protocol ${protocol.get}") + context = SSLContext.getDefault + } + + val providerAlgorithms = context.getServerSocketFactory.getSupportedCipherSuites.toSet + + // Log which algorithms we are discarding + (enabledAlgorithms &~ providerAlgorithms).foreach { cipher => + logDebug(s"Discarding unsupported cipher $cipher") + } + + enabledAlgorithms & providerAlgorithms + } + /** Returns a string representation of this SSLOptions with all the passwords masked. */ override def toString: String = s"SSLOptions{enabled=$enabled, " + s"keyStore=$keyStore, keyStorePassword=${keyStorePassword.map(_ => "xxx")}, " + diff --git a/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala b/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala index 376481ba541fa..25b79bce6ab98 100644 --- a/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark import java.io.File +import javax.net.ssl.SSLContext import com.google.common.io.Files import org.apache.spark.util.Utils @@ -29,6 +30,15 @@ class SSLOptionsSuite extends SparkFunSuite with BeforeAndAfterAll { val keyStorePath = new File(this.getClass.getResource("/keystore").toURI).getAbsolutePath val trustStorePath = new File(this.getClass.getResource("/truststore").toURI).getAbsolutePath + // Pick two cipher suites that the provider knows about + val sslContext = SSLContext.getInstance("TLSv1.2") + sslContext.init(null, null, null) + val algorithms = sslContext + .getServerSocketFactory + .getDefaultCipherSuites + .take(2) + .toSet + val conf = new SparkConf conf.set("spark.ssl.enabled", "true") conf.set("spark.ssl.keyStore", keyStorePath) @@ -36,9 +46,8 @@ class SSLOptionsSuite extends SparkFunSuite with BeforeAndAfterAll { conf.set("spark.ssl.keyPassword", "password") conf.set("spark.ssl.trustStore", trustStorePath) conf.set("spark.ssl.trustStorePassword", "password") - conf.set("spark.ssl.enabledAlgorithms", - "TLS_RSA_WITH_AES_128_CBC_SHA, TLS_RSA_WITH_AES_256_CBC_SHA") - conf.set("spark.ssl.protocol", "SSLv3") + conf.set("spark.ssl.enabledAlgorithms", algorithms.mkString(",")) + conf.set("spark.ssl.protocol", "TLSv1.2") val opts = SSLOptions.parse(conf, "spark.ssl") @@ -52,9 +61,8 @@ class SSLOptionsSuite extends SparkFunSuite with BeforeAndAfterAll { assert(opts.trustStorePassword === Some("password")) assert(opts.keyStorePassword === Some("password")) assert(opts.keyPassword === Some("password")) - assert(opts.protocol === Some("SSLv3")) - assert(opts.enabledAlgorithms === - Set("TLS_RSA_WITH_AES_128_CBC_SHA", "TLS_RSA_WITH_AES_256_CBC_SHA")) + assert(opts.protocol === Some("TLSv1.2")) + assert(opts.enabledAlgorithms === algorithms) } test("test resolving property with defaults specified ") { diff --git a/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala b/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala index 1a099da2c6c8e..33270bec6247c 100644 --- a/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala +++ b/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala @@ -25,6 +25,20 @@ object SSLSampleConfigs { this.getClass.getResource("/untrusted-keystore").toURI).getAbsolutePath val trustStorePath = new File(this.getClass.getResource("/truststore").toURI).getAbsolutePath + 
val enabledAlgorithms = + // A reasonable set of TLSv1.2 Oracle security provider suites + "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA384, " + + "TLS_RSA_WITH_AES_256_CBC_SHA256, " + + "TLS_DHE_RSA_WITH_AES_256_CBC_SHA256, " + + "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, " + + "TLS_DHE_RSA_WITH_AES_128_CBC_SHA256, " + + // and their equivalent names in the IBM Security provider + "SSL_ECDHE_RSA_WITH_AES_256_CBC_SHA384, " + + "SSL_RSA_WITH_AES_256_CBC_SHA256, " + + "SSL_DHE_RSA_WITH_AES_256_CBC_SHA256, " + + "SSL_ECDHE_RSA_WITH_AES_128_CBC_SHA256, " + + "SSL_DHE_RSA_WITH_AES_128_CBC_SHA256" + def sparkSSLConfig(): SparkConf = { val conf = new SparkConf(loadDefaults = false) conf.set("spark.ssl.enabled", "true") @@ -33,9 +47,8 @@ object SSLSampleConfigs { conf.set("spark.ssl.keyPassword", "password") conf.set("spark.ssl.trustStore", trustStorePath) conf.set("spark.ssl.trustStorePassword", "password") - conf.set("spark.ssl.enabledAlgorithms", - "SSL_RSA_WITH_RC4_128_SHA, SSL_RSA_WITH_DES_CBC_SHA") - conf.set("spark.ssl.protocol", "TLSv1") + conf.set("spark.ssl.enabledAlgorithms", enabledAlgorithms) + conf.set("spark.ssl.protocol", "TLSv1.2") conf } @@ -47,9 +60,8 @@ object SSLSampleConfigs { conf.set("spark.ssl.keyPassword", "password") conf.set("spark.ssl.trustStore", trustStorePath) conf.set("spark.ssl.trustStorePassword", "password") - conf.set("spark.ssl.enabledAlgorithms", - "SSL_RSA_WITH_RC4_128_SHA, SSL_RSA_WITH_DES_CBC_SHA") - conf.set("spark.ssl.protocol", "TLSv1") + conf.set("spark.ssl.enabledAlgorithms", enabledAlgorithms) + conf.set("spark.ssl.protocol", "TLSv1.2") conf } diff --git a/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala b/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala index e9b64aa82a17a..f34aefca4eb18 100644 --- a/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala @@ -127,6 +127,17 @@ class SecurityManagerSuite extends SparkFunSuite { test("ssl on setup") { val conf = SSLSampleConfigs.sparkSSLConfig() + val expectedAlgorithms = Set( + "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA384", + "TLS_RSA_WITH_AES_256_CBC_SHA256", + "TLS_DHE_RSA_WITH_AES_256_CBC_SHA256", + "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256", + "TLS_DHE_RSA_WITH_AES_128_CBC_SHA256", + "SSL_ECDHE_RSA_WITH_AES_256_CBC_SHA384", + "SSL_RSA_WITH_AES_256_CBC_SHA256", + "SSL_DHE_RSA_WITH_AES_256_CBC_SHA256", + "SSL_ECDHE_RSA_WITH_AES_128_CBC_SHA256", + "SSL_DHE_RSA_WITH_AES_128_CBC_SHA256") val securityManager = new SecurityManager(conf) @@ -143,9 +154,8 @@ class SecurityManagerSuite extends SparkFunSuite { assert(securityManager.fileServerSSLOptions.trustStorePassword === Some("password")) assert(securityManager.fileServerSSLOptions.keyStorePassword === Some("password")) assert(securityManager.fileServerSSLOptions.keyPassword === Some("password")) - assert(securityManager.fileServerSSLOptions.protocol === Some("TLSv1")) - assert(securityManager.fileServerSSLOptions.enabledAlgorithms === - Set("SSL_RSA_WITH_RC4_128_SHA", "SSL_RSA_WITH_DES_CBC_SHA")) + assert(securityManager.fileServerSSLOptions.protocol === Some("TLSv1.2")) + assert(securityManager.fileServerSSLOptions.enabledAlgorithms === expectedAlgorithms) assert(securityManager.akkaSSLOptions.trustStore.isDefined === true) assert(securityManager.akkaSSLOptions.trustStore.get.getName === "truststore") @@ -154,9 +164,8 @@ class SecurityManagerSuite extends SparkFunSuite { assert(securityManager.akkaSSLOptions.trustStorePassword === Some("password")) 
assert(securityManager.akkaSSLOptions.keyStorePassword === Some("password")) assert(securityManager.akkaSSLOptions.keyPassword === Some("password")) - assert(securityManager.akkaSSLOptions.protocol === Some("TLSv1")) - assert(securityManager.akkaSSLOptions.enabledAlgorithms === - Set("SSL_RSA_WITH_RC4_128_SHA", "SSL_RSA_WITH_DES_CBC_SHA")) + assert(securityManager.akkaSSLOptions.protocol === Some("TLSv1.2")) + assert(securityManager.akkaSSLOptions.enabledAlgorithms === expectedAlgorithms) } test("ssl off setup") { From 08fab4843845136358f3a7251e8d90135126b419 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 30 Jun 2015 07:58:49 -0700 Subject: [PATCH 081/122] [SPARK-8590] [SQL] add code gen for ExtractValue TODO: use array instead of Seq as internal representation for `ArrayType` Author: Wenchen Fan Closes #6982 from cloud-fan/extract-value and squashes the following commits: e203bc1 [Wenchen Fan] address comments 4da0f0b [Wenchen Fan] some clean up f679969 [Wenchen Fan] fix bug e64f942 [Wenchen Fan] remove generic e3f8427 [Wenchen Fan] fix style and address comments fc694e8 [Wenchen Fan] add code gen for extract value --- .../catalyst/expressions/BoundAttribute.scala | 2 +- .../sql/catalyst/expressions/Expression.scala | 46 ++++-- .../catalyst/expressions/ExtractValue.scala | 76 ++++++++-- .../sql/catalyst/expressions/arithmetic.scala | 6 +- .../expressions/codegen/CodeGenerator.scala | 15 +- .../codegen/GenerateMutableProjection.scala | 2 +- .../spark/sql/catalyst/expressions/math.scala | 13 +- .../sql/catalyst/expressions/predicates.scala | 3 - .../spark/sql/catalyst/expressions/sets.scala | 4 - .../spark/sql/catalyst/util/TypeUtils.scala | 2 +- .../expressions/ComplexTypeSuite.scala | 131 ++++++++++-------- 11 files changed, 199 insertions(+), 101 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index 5db2fcfcb267b..dc0b4ac5cd9bb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -47,7 +47,7 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) s""" boolean ${ev.isNull} = i.isNullAt($ordinal); ${ctx.javaType(dataType)} ${ev.primitive} = ${ev.isNull} ? - ${ctx.defaultValue(dataType)} : (${ctx.getColumn(dataType, ordinal)}); + ${ctx.defaultValue(dataType)} : (${ctx.getColumn("i", dataType, ordinal)}); """ } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index e5dc7b9b5c884..aed48921bdeb5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -179,9 +179,10 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express override def toString: String = s"($left $symbol $right)" override def isThreadSafe: Boolean = left.isThreadSafe && right.isThreadSafe + /** - * Short hand for generating binary evaluation code, which depends on two sub-evaluations of - * the same type. If either of the sub-expressions is null, the result of this computation + * Short hand for generating binary evaluation code. 
+ * If either of the sub-expressions is null, the result of this computation * is assumed to be null. * * @param f accepts two variable names and returns Java code to compute the output. @@ -190,15 +191,23 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express ctx: CodeGenContext, ev: GeneratedExpressionCode, f: (String, String) => String): String = { - // TODO: Right now some timestamp tests fail if we enforce this... - if (left.dataType != right.dataType) { - // log.warn(s"${left.dataType} != ${right.dataType}") - } + nullSafeCodeGen(ctx, ev, (result, eval1, eval2) => { + s"$result = ${f(eval1, eval2)};" + }) + } + /** + * Short hand for generating binary evaluation code. + * If either of the sub-expressions is null, the result of this computation + * is assumed to be null. + */ + protected def nullSafeCodeGen( + ctx: CodeGenContext, + ev: GeneratedExpressionCode, + f: (String, String, String) => String): String = { val eval1 = left.gen(ctx) val eval2 = right.gen(ctx) - val resultCode = f(eval1.primitive, eval2.primitive) - + val resultCode = f(ev.primitive, eval1.primitive, eval2.primitive) s""" ${eval1.code} boolean ${ev.isNull} = ${eval1.isNull}; @@ -206,7 +215,7 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express if (!${ev.isNull}) { ${eval2.code} if (!${eval2.isNull}) { - ${ev.primitive} = $resultCode; + $resultCode } else { ${ev.isNull} = true; } @@ -245,13 +254,26 @@ abstract class UnaryExpression extends Expression with trees.UnaryNode[Expressio ctx: CodeGenContext, ev: GeneratedExpressionCode, f: String => String): String = { + nullSafeCodeGen(ctx, ev, (result, eval) => { + s"$result = ${f(eval)};" + }) + } + + /** + * Called by unary expressions to generate a code block that returns null if its parent returns + * null, and if not not null, use `f` to generate the expression. 
+ */ + protected def nullSafeCodeGen( + ctx: CodeGenContext, + ev: GeneratedExpressionCode, + f: (String, String) => String): String = { val eval = child.gen(ctx) - // reuse the previous isNull - ev.isNull = eval.isNull + val resultCode = f(ev.primitive, eval.primitive) eval.code + s""" + boolean ${ev.isNull} = ${eval.isNull}; ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; if (!${ev.isNull}) { - ${ev.primitive} = ${f(eval.primitive)}; + $resultCode } """ } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala index 4d7c95ffd1850..3020e7fc967f2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala @@ -21,6 +21,7 @@ import scala.collection.Map import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis._ +import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} import org.apache.spark.sql.types._ object ExtractValue { @@ -38,7 +39,7 @@ object ExtractValue { def apply( child: Expression, extraction: Expression, - resolver: Resolver): ExtractValue = { + resolver: Resolver): Expression = { (child.dataType, extraction) match { case (StructType(fields), NonNullLiteral(v, StringType)) => @@ -73,7 +74,7 @@ object ExtractValue { def unapply(g: ExtractValue): Option[(Expression, Expression)] = { g match { case o: ExtractValueWithOrdinal => Some((o.child, o.ordinal)) - case _ => Some((g.child, null)) + case s: ExtractValueWithStruct => Some((s.child, null)) } } @@ -101,11 +102,11 @@ object ExtractValue { * Note: concrete extract value expressions are created only by `ExtractValue.apply`, * we don't need to do type check for them. 
*/ -trait ExtractValue extends UnaryExpression { - self: Product => +trait ExtractValue { + self: Expression => } -abstract class ExtractValueWithStruct extends ExtractValue { +abstract class ExtractValueWithStruct extends UnaryExpression with ExtractValue { self: Product => def field: StructField @@ -125,6 +126,18 @@ case class GetStructField(child: Expression, field: StructField, ordinal: Int) val baseValue = child.eval(input).asInstanceOf[InternalRow] if (baseValue == null) null else baseValue(ordinal) } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + nullSafeCodeGen(ctx, ev, (result, eval) => { + s""" + if ($eval.isNullAt($ordinal)) { + ${ev.isNull} = true; + } else { + $result = ${ctx.getColumn(eval, dataType, ordinal)}; + } + """ + }) + } } /** @@ -137,6 +150,7 @@ case class GetArrayStructFields( containsNull: Boolean) extends ExtractValueWithStruct { override def dataType: DataType = ArrayType(field.dataType, containsNull) + override def nullable: Boolean = child.nullable || containsNull || field.nullable override def eval(input: InternalRow): Any = { val baseValue = child.eval(input).asInstanceOf[Seq[InternalRow]] @@ -146,18 +160,39 @@ case class GetArrayStructFields( } } } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val arraySeqClass = "scala.collection.mutable.ArraySeq" + // TODO: consider using Array[_] for ArrayType child to avoid + // boxing of primitives + nullSafeCodeGen(ctx, ev, (result, eval) => { + s""" + final int n = $eval.size(); + final $arraySeqClass values = new $arraySeqClass(n); + for (int j = 0; j < n; j++) { + InternalRow row = (InternalRow) $eval.apply(j); + if (row != null && !row.isNullAt($ordinal)) { + values.update(j, ${ctx.getColumn("row", field.dataType, ordinal)}); + } + } + $result = (${ctx.javaType(dataType)}) values; + """ + }) + } } -abstract class ExtractValueWithOrdinal extends ExtractValue { +abstract class ExtractValueWithOrdinal extends BinaryExpression with ExtractValue { self: Product => def ordinal: Expression + def child: Expression + + override def left: Expression = child + override def right: Expression = ordinal /** `Null` is returned for invalid ordinals. 
*/ override def nullable: Boolean = true - override def foldable: Boolean = child.foldable && ordinal.foldable override def toString: String = s"$child[$ordinal]" - override def children: Seq[Expression] = child :: ordinal :: Nil override def eval(input: InternalRow): Any = { val value = child.eval(input) @@ -195,6 +230,19 @@ case class GetArrayItem(child: Expression, ordinal: Expression) baseValue(index) } } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + nullSafeCodeGen(ctx, ev, (result, eval1, eval2) => { + s""" + final int index = (int)$eval2; + if (index >= $eval1.size() || index < 0) { + ${ev.isNull} = true; + } else { + $result = (${ctx.boxedType(dataType)})$eval1.apply(index); + } + """ + }) + } } /** @@ -209,4 +257,16 @@ case class GetMapValue(child: Expression, ordinal: Expression) val baseValue = value.asInstanceOf[Map[Any, _]] baseValue.get(ordinal).orNull } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + nullSafeCodeGen(ctx, ev, (result, eval1, eval2) => { + s""" + if ($eval1.contains($eval2)) { + $result = (${ctx.boxedType(dataType)})$eval1.apply($eval2); + } else { + ${ev.isNull} = true; + } + """ + }) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 3d4d9e2d798f0..ae765c1653203 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -82,8 +82,6 @@ case class Abs(child: Expression) extends UnaryArithmetic { abstract class BinaryArithmetic extends BinaryExpression { self: Product => - /** Name of the function for this expression on a [[Decimal]] type. */ - def decimalMethod: String = "" override def dataType: DataType = left.dataType @@ -113,6 +111,10 @@ abstract class BinaryArithmetic extends BinaryExpression { } } + /** Name of the function for this expression on a [[Decimal]] type. */ + def decimalMethod: String = + sys.error("BinaryArithmetics must override either decimalMethod or genCode") + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = dataType match { case dt: DecimalType => defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$decimalMethod($eval2)") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 57e0bede5db20..bf6a6a124088e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -82,24 +82,24 @@ class CodeGenContext { /** * Returns the code to access a column in Row for a given DataType. */ - def getColumn(dataType: DataType, ordinal: Int): String = { + def getColumn(row: String, dataType: DataType, ordinal: Int): String = { val jt = javaType(dataType) if (isPrimitiveType(jt)) { - s"i.get${primitiveTypeName(jt)}($ordinal)" + s"$row.get${primitiveTypeName(jt)}($ordinal)" } else { - s"($jt)i.apply($ordinal)" + s"($jt)$row.apply($ordinal)" } } /** * Returns the code to update a column in Row for a given DataType. 
*/ - def setColumn(dataType: DataType, ordinal: Int, value: String): String = { + def setColumn(row: String, dataType: DataType, ordinal: Int, value: String): String = { val jt = javaType(dataType) if (isPrimitiveType(jt)) { - s"set${primitiveTypeName(jt)}($ordinal, $value)" + s"$row.set${primitiveTypeName(jt)}($ordinal, $value)" } else { - s"update($ordinal, $value)" + s"$row.update($ordinal, $value)" } } @@ -127,6 +127,9 @@ class CodeGenContext { case dt: DecimalType => decimalType case BinaryType => "byte[]" case StringType => stringType + case _: StructType => "InternalRow" + case _: ArrayType => s"scala.collection.Seq" + case _: MapType => s"scala.collection.Map" case dt: OpenHashSetUDT if dt.elementType == IntegerType => classOf[IntegerHashSet].getName case dt: OpenHashSetUDT if dt.elementType == LongType => classOf[LongHashSet].getName case _ => "Object" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index 64ef357a4f954..addb8023d9c0b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -43,7 +43,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu if(${evaluationCode.isNull}) mutableRow.setNullAt($i); else - mutableRow.${ctx.setColumn(e.dataType, i, evaluationCode.primitive)}; + ${ctx.setColumn("mutableRow", e.dataType, i, evaluationCode.primitive)}; """ }.mkString("\n") val code = s""" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index a022f3727bd58..da63f2fa970cf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -78,17 +78,14 @@ abstract class UnaryMathExpression(f: Double => Double, name: String) def funcName: String = name.toLowerCase override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val eval = child.gen(ctx) - eval.code + s""" - boolean ${ev.isNull} = ${eval.isNull}; - ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; - if (!${ev.isNull}) { - ${ev.primitive} = java.lang.Math.${funcName}(${eval.primitive}); + nullSafeCodeGen(ctx, ev, (result, eval) => { + s""" + ${ev.primitive} = java.lang.Math.${funcName}($eval); if (Double.valueOf(${ev.primitive}).isNaN()) { ${ev.isNull} = true; } - } - """ + """ + }) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 386cf6a8df6df..98cd5aa8148c4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -69,10 +69,7 @@ trait PredicateHelper { expr.references.subsetOf(plan.outputSet) } - case class Not(child: Expression) extends UnaryExpression with Predicate with AutoCastInputTypes { - override def foldable: Boolean = child.foldable - override def nullable: Boolean = child.nullable override def toString: String = s"NOT $child" override def 
expectedChildTypes: Seq[DataType] = Seq(BooleanType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala index efc6f50b78943..daa9f4403ffab 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/sets.scala @@ -135,8 +135,6 @@ case class AddItemToSet(item: Expression, set: Expression) extends Expression { */ case class CombineSets(left: Expression, right: Expression) extends BinaryExpression { - override def nullable: Boolean = left.nullable || right.nullable - override def dataType: DataType = left.dataType override def symbol: String = "++=" @@ -185,8 +183,6 @@ case class CombineSets(left: Expression, right: Expression) extends BinaryExpres */ case class CountSet(child: Expression) extends UnaryExpression { - override def nullable: Boolean = child.nullable - override def dataType: DataType = LongType override def eval(input: InternalRow): Any = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala index 8656cc334d09f..3148309a2166f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala @@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.types._ /** - * Helper function to check for valid data types + * Helper functions to check for valid data types. */ object TypeUtils { def checkForNumericExpr(t: DataType, caller: String): TypeCheckResult = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index b80911e7257fc..3515d044b2f7e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -40,51 +40,42 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { } test("GetArrayItem") { + val typeA = ArrayType(StringType) + val array = Literal.create(Seq("a", "b"), typeA) testIntegralDataTypes { convert => - val array = Literal.create(Seq("a", "b"), ArrayType(StringType)) checkEvaluation(GetArrayItem(array, Literal(convert(1))), "b") } + val nullArray = Literal.create(null, typeA) + val nullInt = Literal.create(null, IntegerType) + checkEvaluation(GetArrayItem(nullArray, Literal(1)), null) + checkEvaluation(GetArrayItem(array, nullInt), null) + checkEvaluation(GetArrayItem(nullArray, nullInt), null) + + val nestedArray = Literal.create(Seq(Seq(1)), ArrayType(ArrayType(IntegerType))) + checkEvaluation(GetArrayItem(nestedArray, Literal(0)), Seq(1)) } - test("CreateStruct") { - val row = InternalRow(1, 2, 3) - val c1 = 'a.int.at(0).as("a") - val c3 = 'c.int.at(2).as("c") - checkEvaluation(CreateStruct(Seq(c1, c3)), InternalRow(1, 3), row) + test("GetMapValue") { + val typeM = MapType(StringType, StringType) + val map = Literal.create(Map("a" -> "b"), typeM) + val nullMap = Literal.create(null, typeM) + val nullString = Literal.create(null, StringType) + + checkEvaluation(GetMapValue(map, Literal("a")), "b") + checkEvaluation(GetMapValue(map, nullString), 
null) + checkEvaluation(GetMapValue(nullMap, nullString), null) + checkEvaluation(GetMapValue(map, nullString), null) + + val nestedMap = Literal.create(Map("a" -> Map("b" -> "c")), MapType(StringType, typeM)) + checkEvaluation(GetMapValue(nestedMap, Literal("a")), Map("b" -> "c")) } - test("complex type") { - val row = create_row( - "^Ba*n", // 0 - null.asInstanceOf[UTF8String], // 1 - create_row("aa", "bb"), // 2 - Map("aa"->"bb"), // 3 - Seq("aa", "bb") // 4 - ) - - val typeS = StructType( - StructField("a", StringType, true) :: StructField("b", StringType, true) :: Nil - ) - val typeMap = MapType(StringType, StringType) - val typeArray = ArrayType(StringType) - - checkEvaluation(GetMapValue(BoundReference(3, typeMap, true), - Literal("aa")), "bb", row) - checkEvaluation(GetMapValue(Literal.create(null, typeMap), Literal("aa")), null, row) - checkEvaluation( - GetMapValue(Literal.create(null, typeMap), Literal.create(null, StringType)), null, row) - checkEvaluation(GetMapValue(BoundReference(3, typeMap, true), - Literal.create(null, StringType)), null, row) - - checkEvaluation(GetArrayItem(BoundReference(4, typeArray, true), - Literal(1)), "bb", row) - checkEvaluation(GetArrayItem(Literal.create(null, typeArray), Literal(1)), null, row) - checkEvaluation( - GetArrayItem(Literal.create(null, typeArray), Literal.create(null, IntegerType)), null, row) - checkEvaluation(GetArrayItem(BoundReference(4, typeArray, true), - Literal.create(null, IntegerType)), null, row) - - def getStructField(expr: Expression, fieldName: String): ExtractValue = { + test("GetStructField") { + val typeS = StructType(StructField("a", IntegerType) :: Nil) + val struct = Literal.create(create_row(1), typeS) + val nullStruct = Literal.create(null, typeS) + + def getStructField(expr: Expression, fieldName: String): GetStructField = { expr.dataType match { case StructType(fields) => val field = fields.find(_.name == fieldName).get @@ -92,28 +83,58 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { } } - def quickResolve(u: UnresolvedExtractValue): ExtractValue = { - ExtractValue(u.child, u.extraction, _ == _) - } + checkEvaluation(getStructField(struct, "a"), 1) + checkEvaluation(getStructField(nullStruct, "a"), null) + + val nestedStruct = Literal.create(create_row(create_row(1)), + StructType(StructField("a", typeS) :: Nil)) + checkEvaluation(getStructField(nestedStruct, "a"), create_row(1)) + + val typeS_fieldNotNullable = StructType(StructField("a", IntegerType, false) :: Nil) + val struct_fieldNotNullable = Literal.create(create_row(1), typeS_fieldNotNullable) + val nullStruct_fieldNotNullable = Literal.create(null, typeS_fieldNotNullable) + + assert(getStructField(struct_fieldNotNullable, "a").nullable === false) + assert(getStructField(struct, "a").nullable === true) + assert(getStructField(nullStruct_fieldNotNullable, "a").nullable === true) + assert(getStructField(nullStruct, "a").nullable === true) + } - checkEvaluation(getStructField(BoundReference(2, typeS, nullable = true), "a"), "aa", row) - checkEvaluation(getStructField(Literal.create(null, typeS), "a"), null, row) + test("GetArrayStructFields") { + val typeAS = ArrayType(StructType(StructField("a", IntegerType) :: Nil)) + val arrayStruct = Literal.create(Seq(create_row(1)), typeAS) + val nullArrayStruct = Literal.create(null, typeAS) - val typeS_notNullable = StructType( - StructField("a", StringType, nullable = false) - :: StructField("b", StringType, nullable = false) :: Nil - ) + def getArrayStructFields(expr: Expression, 
fieldName: String): GetArrayStructFields = { + expr.dataType match { + case ArrayType(StructType(fields), containsNull) => + val field = fields.find(_.name == fieldName).get + GetArrayStructFields(expr, field, fields.indexOf(field), containsNull) + } + } + + checkEvaluation(getArrayStructFields(arrayStruct, "a"), Seq(1)) + checkEvaluation(getArrayStructFields(nullArrayStruct, "a"), null) + } - assert(getStructField(BoundReference(2, typeS, nullable = true), "a").nullable === true) - assert(getStructField(BoundReference(2, typeS_notNullable, nullable = false), "a").nullable - === false) + test("CreateStruct") { + val row = create_row(1, 2, 3) + val c1 = 'a.int.at(0).as("a") + val c3 = 'c.int.at(2).as("c") + checkEvaluation(CreateStruct(Seq(c1, c3)), create_row(1, 3), row) + } - assert(getStructField(Literal.create(null, typeS), "a").nullable === true) - assert(getStructField(Literal.create(null, typeS_notNullable), "a").nullable === true) + test("test dsl for complex type") { + def quickResolve(u: UnresolvedExtractValue): Expression = { + ExtractValue(u.child, u.extraction, _ == _) + } - checkEvaluation(quickResolve('c.map(typeMap).at(3).getItem("aa")), "bb", row) - checkEvaluation(quickResolve('c.array(typeArray.elementType).at(4).getItem(1)), "bb", row) - checkEvaluation(quickResolve('c.struct(typeS).at(2).getField("a")), "aa", row) + checkEvaluation(quickResolve('c.map(MapType(StringType, StringType)).at(0).getItem("a")), + "b", create_row(Map("a" -> "b"))) + checkEvaluation(quickResolve('c.array(StringType).at(0).getItem(1)), + "b", create_row(Seq("a", "b"))) + checkEvaluation(quickResolve('c.struct(StructField("a", IntegerType)).at(0).getField("a")), + 1, create_row(create_row(1))) } test("error message of ExtractValue") { From 865a834e51ac3074811a11fd99a36d942f7f7de8 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 30 Jun 2015 08:08:15 -0700 Subject: [PATCH 082/122] [SPARK-8723] [SQL] improve divide and remainder code gen We can avoid execution of both left and right expression by null and zero check. 
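In other words, the generated code now evaluates the divisor first and only evaluates the dividend when the divisor is neither null nor zero. A rough Scala sketch of that evaluation order (illustrative only; the helper name and the Int specialization are assumptions here — the actual change emits Java source specialized per data type):

    // Sketch of the short-circuit order the generated divide now follows.
    def nullSafeDivide(evalDividend: () => Any, evalDivisor: () => Any): Any = {
      val divisor = evalDivisor()            // right child is evaluated first
      if (divisor == null || divisor == 0) {
        null                                 // dividend is never evaluated
      } else {
        val dividend = evalDividend()
        if (dividend == null) null
        else dividend.asInstanceOf[Int] / divisor.asInstanceOf[Int]
      }
    }
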
Author: Wenchen Fan Closes #7111 from cloud-fan/cg and squashes the following commits: d6b12ef [Wenchen Fan] improve divide and remainder code gen --- .../sql/catalyst/expressions/arithmetic.scala | 54 ++++++++++++------- 1 file changed, 36 insertions(+), 18 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index ae765c1653203..5363b3556886a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -216,23 +216,32 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val eval1 = left.gen(ctx) val eval2 = right.gen(ctx) - val test = if (left.dataType.isInstanceOf[DecimalType]) { + val isZero = if (dataType.isInstanceOf[DecimalType]) { s"${eval2.primitive}.isZero()" } else { s"${eval2.primitive} == 0" } - val method = if (left.dataType.isInstanceOf[DecimalType]) s".$decimalMethod" else s" $symbol " - val javaType = ctx.javaType(left.dataType) - eval1.code + eval2.code + - s""" + val javaType = ctx.javaType(dataType) + val divide = if (dataType.isInstanceOf[DecimalType]) { + s"${eval1.primitive}.$decimalMethod(${eval2.primitive})" + } else { + s"($javaType)(${eval1.primitive} $symbol ${eval2.primitive})" + } + s""" + ${eval2.code} boolean ${ev.isNull} = false; - ${ctx.javaType(left.dataType)} ${ev.primitive} = ${ctx.defaultValue(left.dataType)}; - if (${eval1.isNull} || ${eval2.isNull} || $test) { + $javaType ${ev.primitive} = ${ctx.defaultValue(javaType)}; + if (${eval2.isNull} || $isZero) { ${ev.isNull} = true; } else { - ${ev.primitive} = ($javaType) (${eval1.primitive}$method(${eval2.primitive})); + ${eval1.code} + if (${eval1.isNull}) { + ${ev.isNull} = true; + } else { + ${ev.primitive} = $divide; + } } - """ + """ } } @@ -273,23 +282,32 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val eval1 = left.gen(ctx) val eval2 = right.gen(ctx) - val test = if (left.dataType.isInstanceOf[DecimalType]) { + val isZero = if (dataType.isInstanceOf[DecimalType]) { s"${eval2.primitive}.isZero()" } else { s"${eval2.primitive} == 0" } - val method = if (left.dataType.isInstanceOf[DecimalType]) s".$decimalMethod" else s" $symbol " - val javaType = ctx.javaType(left.dataType) - eval1.code + eval2.code + - s""" + val javaType = ctx.javaType(dataType) + val remainder = if (dataType.isInstanceOf[DecimalType]) { + s"${eval1.primitive}.$decimalMethod(${eval2.primitive})" + } else { + s"($javaType)(${eval1.primitive} $symbol ${eval2.primitive})" + } + s""" + ${eval2.code} boolean ${ev.isNull} = false; - ${ctx.javaType(left.dataType)} ${ev.primitive} = ${ctx.defaultValue(left.dataType)}; - if (${eval1.isNull} || ${eval2.isNull} || $test) { + $javaType ${ev.primitive} = ${ctx.defaultValue(javaType)}; + if (${eval2.isNull} || $isZero) { ${ev.isNull} = true; } else { - ${ev.primitive} = ($javaType) (${eval1.primitive}$method(${eval2.primitive})); + ${eval1.code} + if (${eval1.isNull}) { + ${ev.isNull} = true; + } else { + ${ev.primitive} = $remainder; + } } - """ + """ } } From a48e61915354d33fb98944a8eb5a5d48dd102041 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 30 Jun 2015 08:17:24 -0700 
Subject: [PATCH 083/122] [SPARK-8680] [SQL] Slightly improve PropagateTypes JIRA: https://issues.apache.org/jira/browse/SPARK-8680 This PR slightly improve `PropagateTypes` in `HiveTypeCoercion`. It moves `q.inputSet` outside `q transformExpressions` instead calling `inputSet` multiple times. It also builds a map of attributes for looking attribute easily. Author: Liang-Chi Hsieh Closes #7087 from viirya/improve_propagatetypes and squashes the following commits: 5c314c1 [Liang-Chi Hsieh] For comments. 913f6ad [Liang-Chi Hsieh] Slightly improve PropagateTypes. --- .../catalyst/analysis/HiveTypeCoercion.scala | 30 ++++++++++--------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index c3d68197d64ac..e525ad623ff12 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -131,20 +131,22 @@ trait HiveTypeCoercion { // Don't propagate types from unresolved children. case q: LogicalPlan if !q.childrenResolved => q - case q: LogicalPlan => q transformExpressions { - case a: AttributeReference => - q.inputSet.find(_.exprId == a.exprId) match { - // This can happen when a Attribute reference is born in a non-leaf node, for example - // due to a call to an external script like in the Transform operator. - // TODO: Perhaps those should actually be aliases? - case None => a - // Leave the same if the dataTypes match. - case Some(newType) if a.dataType == newType.dataType => a - case Some(newType) => - logDebug(s"Promoting $a to $newType in ${q.simpleString}}") - newType - } - } + case q: LogicalPlan => + val inputMap = q.inputSet.toSeq.map(a => (a.exprId, a)).toMap + q transformExpressions { + case a: AttributeReference => + inputMap.get(a.exprId) match { + // This can happen when a Attribute reference is born in a non-leaf node, for example + // due to a call to an external script like in the Transform operator. + // TODO: Perhaps those should actually be aliases? + case None => a + // Leave the same if the dataTypes match. 
+ case Some(newType) if a.dataType == newType.dataType => a + case Some(newType) => + logDebug(s"Promoting $a to $newType in ${q.simpleString}}") + newType + } + } } } From 722aa5f48ec105bf23eee2361adddfe3a0cd6fc4 Mon Sep 17 00:00:00 2001 From: Shilei Date: Tue, 30 Jun 2015 09:49:58 -0700 Subject: [PATCH 084/122] [SPARK-8236] [SQL] misc functions: crc32 https://issues.apache.org/jira/browse/SPARK-8236 Author: Shilei Closes #7108 from qiansl127/Crc32 and squashes the following commits: 5477352 [Shilei] Change to AutoCastInputTypes 5f16e5d [Shilei] Add misc function crc32 --- .../catalyst/analysis/FunctionRegistry.scala | 1 + .../spark/sql/catalyst/expressions/misc.scala | 40 +++++++++++++++++++ .../expressions/MiscFunctionsSuite.scala | 8 ++++ .../org/apache/spark/sql/functions.scala | 16 ++++++++ .../spark/sql/DataFrameFunctionsSuite.scala | 11 +++++ 5 files changed, 76 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index b17457d3094c2..d53eaedda56b0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -139,6 +139,7 @@ object FunctionRegistry { expression[Sha2]("sha2"), expression[Sha1]("sha1"), expression[Sha1]("sha"), + expression[Crc32]("crc32"), // aggregate functions expression[Average]("avg"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index 27805bff293f4..a7bcbe46c339a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import java.security.MessageDigest import java.security.NoSuchAlgorithmException +import java.util.zip.CRC32 import org.apache.commons.codec.digest.DigestUtils import org.apache.spark.sql.catalyst.analysis.TypeCheckResult @@ -168,3 +169,42 @@ case class Sha1(child: Expression) extends UnaryExpression with AutoCastInputTyp ) } } + +/** + * A function that computes a cyclic redundancy check value and returns it as a bigint + * For input of type [[BinaryType]] + */ +case class Crc32(child: Expression) + extends UnaryExpression with AutoCastInputTypes { + + override def dataType: DataType = LongType + + override def expectedChildTypes: Seq[DataType] = Seq(BinaryType) + + override def eval(input: InternalRow): Any = { + val value = child.eval(input) + if (value == null) { + null + } else { + val checksum = new CRC32 + checksum.update(value.asInstanceOf[Array[Byte]], 0, value.asInstanceOf[Array[Byte]].length) + checksum.getValue + } + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val value = child.gen(ctx) + val CRC32 = "java.util.zip.CRC32" + s""" + ${value.code} + boolean ${ev.isNull} = ${value.isNull}; + long ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + ${CRC32} checksum = new ${CRC32}(); + checksum.update(${value.primitive}, 0, ${value.primitive}.length); + ${ev.primitive} = checksum.getValue(); + } + """ + } + +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala index 36e636b5da6b8..b524d0af14a67 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala @@ -49,4 +49,12 @@ class MiscFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Sha2(Literal("ABC".getBytes), Literal.create(null, IntegerType)), null) checkEvaluation(Sha2(Literal.create(null, BinaryType), Literal.create(null, IntegerType)), null) } + + test("crc32") { + checkEvaluation(Crc32(Literal("ABC".getBytes)), 2743272264L) + checkEvaluation(Crc32(Literal.create(Array[Byte](1, 2, 3, 4, 5, 6), BinaryType)), + 2180413220L) + checkEvaluation(Crc32(Literal.create(null, BinaryType)), null) + } + } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 4d9a019058228..6331fe61052ab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1466,6 +1466,22 @@ object functions { */ def sha2(columnName: String, numBits: Int): Column = sha2(Column(columnName), numBits) + /** + * Calculates the cyclic redundancy check value and returns the value as a bigint. + * + * @group misc_funcs + * @since 1.5.0 + */ + def crc32(e: Column): Column = Crc32(e.expr) + + /** + * Calculates the cyclic redundancy check value and returns the value as a bigint. + * + * @group misc_funcs + * @since 1.5.0 + */ + def crc32(columnName: String): Column = crc32(Column(columnName)) + ////////////////////////////////////////////////////////////////////////////////////////////// // String functions ////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index abfd47c811ed9..11a8767ead96c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -173,6 +173,17 @@ class DataFrameFunctionsSuite extends QueryTest { } } + test("misc crc32 function") { + val df = Seq(("ABC", Array[Byte](1, 2, 3, 4, 5, 6))).toDF("a", "b") + checkAnswer( + df.select(crc32($"a"), crc32("b")), + Row(2743272264L, 2180413220L)) + + checkAnswer( + df.selectExpr("crc32(a)", "crc32(b)"), + Row(2743272264L, 2180413220L)) + } + test("string length function") { checkAnswer( nullStrings.select(strlen($"s"), strlen("s")), From 689da28a53cf720ae607a1a935093612a7001615 Mon Sep 17 00:00:00 2001 From: xuchenCN Date: Tue, 30 Jun 2015 10:05:51 -0700 Subject: [PATCH 085/122] [SPARK-8592] [CORE] CoarseGrainedExecutorBackend: Cannot register with driver => NPE Look detail of this issue at [SPARK-8592](https://issues.apache.org/jira/browse/SPARK-8592) **CoarseGrainedExecutorBackend** should exit when **RegisterExecutor** failed Author: xuchenCN Closes #7110 from xuchenCN/SPARK-8592 and squashes the following commits: 71e0077 [xuchenCN] [SPARK-8592] [CORE] CoarseGrainedExecutorBackend: Cannot register with driver => NPE --- .../apache/spark/executor/CoarseGrainedExecutorBackend.scala | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git 
a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index f3a26f54a81fb..34d4cfdca7732 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -66,7 +66,10 @@ private[spark] class CoarseGrainedExecutorBackend( case Success(msg) => Utils.tryLogNonFatalError { Option(self).foreach(_.send(msg)) // msg must be RegisteredExecutor } - case Failure(e) => logError(s"Cannot register with driver: $driverUrl", e) + case Failure(e) => { + logError(s"Cannot register with driver: $driverUrl", e) + System.exit(1) + } }(ThreadUtils.sameThread) } From ada384b785c663392a0b69fad5bfe7a0a0584ee0 Mon Sep 17 00:00:00 2001 From: Sean Owen Date: Tue, 30 Jun 2015 10:07:26 -0700 Subject: [PATCH 086/122] [SPARK-8437] [DOCS] Corrected: Using directory path without wildcard for filename slow for large number of files with wholeTextFiles and binaryFiles Note that 'dir/*' can be more efficient in some Hadoop FS implementations that 'dir/' (now fixed scaladoc by using HTML entity for *) Author: Sean Owen Closes #7126 from srowen/SPARK-8437.2 and squashes the following commits: 7bb45da [Sean Owen] Note that 'dir/*' can be more efficient in some Hadoop FS implementations that 'dir/' (now fixed scaladoc by using HTML entity for *) --- core/src/main/scala/org/apache/spark/SparkContext.scala | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index b3c3bf3746e18..0e5a86f44e410 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -831,7 +831,8 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * }}} * * @note Small files are preferred, large file is also allowable, but may cause bad performance. - * + * @note On some filesystems, `.../path/*` can be a more efficient way to read all files + * in a directory rather than `.../path/` or `.../path` * @param minPartitions A suggestion value of the minimal splitting number for input data. */ def wholeTextFiles( @@ -878,9 +879,10 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * (a-hdfs-path/part-nnnnn, its content) * }}} * - * @param minPartitions A suggestion value of the minimal splitting number for input data. - * * @note Small files are preferred; very large files may cause bad performance. + * @note On some filesystems, `.../path/*` can be a more efficient way to read all files + * in a directory rather than `.../path/` or `.../path` + * @param minPartitions A suggestion value of the minimal splitting number for input data. 
*/ @Experimental def binaryFiles( From 45281664e0d3b22cd63660ca8ad6dd574f10e21f Mon Sep 17 00:00:00 2001 From: MechCoder Date: Tue, 30 Jun 2015 10:25:59 -0700 Subject: [PATCH 087/122] [SPARK-4127] [MLLIB] [PYSPARK] Python bindings for StreamingLinearRegressionWithSGD Python bindings for StreamingLinearRegressionWithSGD Author: MechCoder Closes #6744 from MechCoder/spark-4127 and squashes the following commits: d8f6457 [MechCoder] Moved StreamingLinearAlgorithm to pyspark.mllib.regression d47cc24 [MechCoder] Inherit from StreamingLinearAlgorithm 1b4ddd6 [MechCoder] minor 4de6c68 [MechCoder] Minor refactor 5e85a3b [MechCoder] Add tests for simultaneous training and prediction fb27889 [MechCoder] Add example and docs 505380b [MechCoder] Add tests d42bdae [MechCoder] [SPARK-4127] Python bindings for StreamingLinearRegressionWithSGD --- docs/mllib-linear-methods.md | 52 +++++++++++ python/pyspark/mllib/classification.py | 50 +--------- python/pyspark/mllib/regression.py | 90 ++++++++++++++++++ python/pyspark/mllib/tests.py | 124 ++++++++++++++++++++++++- 4 files changed, 269 insertions(+), 47 deletions(-) diff --git a/docs/mllib-linear-methods.md b/docs/mllib-linear-methods.md index 3dc8cc902fa72..2a2a7c13186d8 100644 --- a/docs/mllib-linear-methods.md +++ b/docs/mllib-linear-methods.md @@ -768,6 +768,58 @@ will get better! +
+ +First, we import the necessary classes for parsing our input data and creating the model. + +{% highlight python %} +from pyspark.mllib.linalg import Vectors +from pyspark.mllib.regression import LabeledPoint +from pyspark.mllib.regression import StreamingLinearRegressionWithSGD +{% endhighlight %} + +Then we make input streams for training and testing data. We assume a StreamingContext `ssc` +has already been created, see [Spark Streaming Programming Guide](streaming-programming-guide.html#initializing) +for more info. For this example, we use labeled points in training and testing streams, +but in practice you will likely want to use unlabeled vectors for test data. + +{% highlight python %} +def parse(lp): + label = float(lp[lp.find('(') + 1: lp.find(',')]) + vec = Vectors.dense(lp[lp.find('[') + 1: lp.find(']')].split(',')) + return LabeledPoint(label, vec) + +trainingData = ssc.textFileStream("/training/data/dir").map(parse).cache() +testData = ssc.textFileStream("/testing/data/dir").map(parse) +{% endhighlight %} + +We create our model by initializing the weights to 0 + +{% highlight python %} +numFeatures = 3 +model = StreamingLinearRegressionWithSGD() +model.setInitialWeights([0.0, 0.0, 0.0]) +{% endhighlight %} + +Now we register the streams for training and testing and start the job. + +{% highlight python %} +model.trainOn(trainingData) +print(model.predictOnValues(testData.map(lambda lp: (lp.label, lp.features)))) + +ssc.start() +ssc.awaitTermination() +{% endhighlight %} + +We can now save text files with data to the training or testing folders. +Each line should be a data point formatted as `(y,[x1,x2,x3])` where `y` is the label +and `x1,x2,x3` are the features. Anytime a text file is placed in `/training/data/dir` +the model will update. Anytime a text file is placed in `/testing/data/dir` you will see predictions. +As you feed more data to the training directory, the predictions +will get better! + +
+ diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py index 735d45ba03d27..8f27c446a66e8 100644 --- a/python/pyspark/mllib/classification.py +++ b/python/pyspark/mllib/classification.py @@ -24,7 +24,9 @@ from pyspark.streaming import DStream from pyspark.mllib.common import callMLlibFunc, _py2java, _java2py from pyspark.mllib.linalg import DenseVector, SparseVector, _convert_to_vector -from pyspark.mllib.regression import LabeledPoint, LinearModel, _regression_train_wrapper +from pyspark.mllib.regression import ( + LabeledPoint, LinearModel, _regression_train_wrapper, + StreamingLinearAlgorithm) from pyspark.mllib.util import Saveable, Loader, inherit_doc @@ -585,55 +587,13 @@ def train(cls, data, lambda_=1.0): return NaiveBayesModel(labels.toArray(), pi.toArray(), numpy.array(theta)) -class StreamingLinearAlgorithm(object): - """ - Base class that has to be inherited by any StreamingLinearAlgorithm. - - Prevents reimplementation of methods predictOn and predictOnValues. - """ - def __init__(self, model): - self._model = model - - def latestModel(self): - """ - Returns the latest model. - """ - return self._model - - def _validate(self, dstream): - if not isinstance(dstream, DStream): - raise TypeError( - "dstream should be a DStream object, got %s" % type(dstream)) - if not self._model: - raise ValueError( - "Model must be intialized using setInitialWeights") - - def predictOn(self, dstream): - """ - Make predictions on a dstream. - - :return: Transformed dstream object. - """ - self._validate(dstream) - return dstream.map(lambda x: self._model.predict(x)) - - def predictOnValues(self, dstream): - """ - Make predictions on a keyed dstream. - - :return: Transformed dstream object. - """ - self._validate(dstream) - return dstream.mapValues(lambda x: self._model.predict(x)) - - @inherit_doc class StreamingLogisticRegressionWithSGD(StreamingLinearAlgorithm): """ - Run LogisticRegression with SGD on a stream of data. + Run LogisticRegression with SGD on a batch of data. The weights obtained at the end of training a stream are used as initial - weights for the next stream. + weights for the next batch. :param stepSize: Step size for each iteration of gradient descent. :param numIterations: Number of iterations run for each batch of data. diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py index 5ddbbee4babdd..8e90adee5f4c2 100644 --- a/python/pyspark/mllib/regression.py +++ b/python/pyspark/mllib/regression.py @@ -19,6 +19,7 @@ from numpy import array from pyspark import RDD +from pyspark.streaming.dstream import DStream from pyspark.mllib.common import callMLlibFunc, _py2java, _java2py, inherit_doc from pyspark.mllib.linalg import SparseVector, Vectors, _convert_to_vector from pyspark.mllib.util import Saveable, Loader @@ -570,6 +571,95 @@ def train(cls, data, isotonic=True): return IsotonicRegressionModel(boundaries.toArray(), predictions.toArray(), isotonic) +class StreamingLinearAlgorithm(object): + """ + Base class that has to be inherited by any StreamingLinearAlgorithm. + + Prevents reimplementation of methods predictOn and predictOnValues. + """ + def __init__(self, model): + self._model = model + + def latestModel(self): + """ + Returns the latest model. 
+ """ + return self._model + + def _validate(self, dstream): + if not isinstance(dstream, DStream): + raise TypeError( + "dstream should be a DStream object, got %s" % type(dstream)) + if not self._model: + raise ValueError( + "Model must be intialized using setInitialWeights") + + def predictOn(self, dstream): + """ + Make predictions on a dstream. + + :return: Transformed dstream object. + """ + self._validate(dstream) + return dstream.map(lambda x: self._model.predict(x)) + + def predictOnValues(self, dstream): + """ + Make predictions on a keyed dstream. + + :return: Transformed dstream object. + """ + self._validate(dstream) + return dstream.mapValues(lambda x: self._model.predict(x)) + + +@inherit_doc +class StreamingLinearRegressionWithSGD(StreamingLinearAlgorithm): + """ + Run LinearRegression with SGD on a batch of data. + + The problem minimized is (1 / n_samples) * (y - weights'X)**2. + After training on a batch of data, the weights obtained at the end of + training are used as initial weights for the next batch. + + :param: stepSize Step size for each iteration of gradient descent. + :param: numIterations Total number of iterations run. + :param: miniBatchFraction Fraction of data on which SGD is run for each + iteration. + """ + def __init__(self, stepSize=0.1, numIterations=50, miniBatchFraction=1.0): + self.stepSize = stepSize + self.numIterations = numIterations + self.miniBatchFraction = miniBatchFraction + self._model = None + super(StreamingLinearRegressionWithSGD, self).__init__( + model=self._model) + + def setInitialWeights(self, initialWeights): + """ + Set the initial value of weights. + + This must be set before running trainOn and predictOn + """ + initialWeights = _convert_to_vector(initialWeights) + self._model = LinearRegressionModel(initialWeights, 0) + return self + + def trainOn(self, dstream): + """Train the model on the incoming dstream.""" + self._validate(dstream) + + def update(rdd): + # LinearRegressionWithSGD.train raises an error for an empty RDD. 
+ if not rdd.isEmpty(): + self._model = LinearRegressionWithSGD.train( + rdd, self.numIterations, self.stepSize, + self.miniBatchFraction, self._model.weights, + self._model.intercept) + + dstream.foreachRDD(update) + + def _test(): import doctest from pyspark import SparkContext diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index cd80c3e07a4f7..f0091d6faccce 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -27,8 +27,9 @@ from shutil import rmtree from numpy import ( - array, array_equal, zeros, inf, random, exp, dot, all, mean) + array, array_equal, zeros, inf, random, exp, dot, all, mean, abs) from numpy import sum as array_sum + from py4j.protocol import Py4JJavaError if sys.version_info[:2] <= (2, 6): @@ -45,8 +46,8 @@ from pyspark.mllib.clustering import StreamingKMeans, StreamingKMeansModel from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, VectorUDT, _convert_to_vector,\ DenseMatrix, SparseMatrix, Vectors, Matrices, MatrixUDT -from pyspark.mllib.regression import LabeledPoint from pyspark.mllib.classification import StreamingLogisticRegressionWithSGD +from pyspark.mllib.regression import LabeledPoint, StreamingLinearRegressionWithSGD from pyspark.mllib.random import RandomRDDs from pyspark.mllib.stat import Statistics from pyspark.mllib.feature import Word2Vec @@ -56,6 +57,7 @@ from pyspark.serializers import PickleSerializer from pyspark.streaming import StreamingContext from pyspark.sql import SQLContext +from pyspark.streaming import StreamingContext _have_scipy = False try: @@ -1170,6 +1172,124 @@ def collect_errors(rdd): self.assertTrue(errors[1] - errors[-1] > 0.3) +class StreamingLinearRegressionWithTests(MLLibStreamingTestCase): + + def assertArrayAlmostEqual(self, array1, array2, dec): + for i, j in array1, array2: + self.assertAlmostEqual(i, j, dec) + + def test_parameter_accuracy(self): + """Test that coefs are predicted accurately by fitting on toy data.""" + + # Test that fitting (10*X1 + 10*X2), (X1, X2) gives coefficients + # (10, 10) + slr = StreamingLinearRegressionWithSGD(stepSize=0.2, numIterations=25) + slr.setInitialWeights([0.0, 0.0]) + xMean = [0.0, 0.0] + xVariance = [1.0 / 3.0, 1.0 / 3.0] + + # Create ten batches with 100 sample points in each. + batches = [] + for i in range(10): + batch = LinearDataGenerator.generateLinearInput( + 0.0, [10.0, 10.0], xMean, xVariance, 100, 42 + i, 0.1) + batches.append(sc.parallelize(batch)) + + input_stream = self.ssc.queueStream(batches) + t = time() + slr.trainOn(input_stream) + self.ssc.start() + self._ssc_wait(t, 10, 0.01) + self.assertArrayAlmostEqual( + slr.latestModel().weights.array, [10., 10.], 1) + self.assertAlmostEqual(slr.latestModel().intercept, 0.0, 1) + + def test_parameter_convergence(self): + """Test that the model parameters improve with streaming data.""" + slr = StreamingLinearRegressionWithSGD(stepSize=0.2, numIterations=25) + slr.setInitialWeights([0.0]) + + # Create ten batches with 100 sample points in each. 
+ batches = [] + for i in range(10): + batch = LinearDataGenerator.generateLinearInput( + 0.0, [10.0], [0.0], [1.0 / 3.0], 100, 42 + i, 0.1) + batches.append(sc.parallelize(batch)) + + model_weights = [] + input_stream = self.ssc.queueStream(batches) + input_stream.foreachRDD( + lambda x: model_weights.append(slr.latestModel().weights[0])) + t = time() + slr.trainOn(input_stream) + self.ssc.start() + self._ssc_wait(t, 10, 0.01) + + model_weights = array(model_weights) + diff = model_weights[1:] - model_weights[:-1] + self.assertTrue(all(diff >= -0.1)) + + def test_prediction(self): + """Test prediction on a model with weights already set.""" + # Create a model with initial Weights equal to coefs + slr = StreamingLinearRegressionWithSGD(stepSize=0.2, numIterations=25) + slr.setInitialWeights([10.0, 10.0]) + + # Create ten batches with 100 sample points in each. + batches = [] + for i in range(10): + batch = LinearDataGenerator.generateLinearInput( + 0.0, [10.0, 10.0], [0.0, 0.0], [1.0 / 3.0, 1.0 / 3.0], + 100, 42 + i, 0.1) + batches.append( + sc.parallelize(batch).map(lambda lp: (lp.label, lp.features))) + + input_stream = self.ssc.queueStream(batches) + t = time() + output_stream = slr.predictOnValues(input_stream) + samples = [] + output_stream.foreachRDD(lambda x: samples.append(x.collect())) + + self.ssc.start() + self._ssc_wait(t, 5, 0.01) + + # Test that mean absolute error on each batch is less than 0.1 + for batch in samples: + true, predicted = zip(*batch) + self.assertTrue(mean(abs(array(true) - array(predicted))) < 0.1) + + def test_train_prediction(self): + """Test that error on test data improves as model is trained.""" + slr = StreamingLinearRegressionWithSGD(stepSize=0.2, numIterations=25) + slr.setInitialWeights([0.0]) + + # Create ten batches with 100 sample points in each. + batches = [] + for i in range(10): + batch = LinearDataGenerator.generateLinearInput( + 0.0, [10.0], [0.0], [1.0 / 3.0], 100, 42 + i, 0.1) + batches.append(sc.parallelize(batch)) + + predict_batches = [ + b.map(lambda lp: (lp.label, lp.features)) for b in batches] + mean_absolute_errors = [] + + def func(rdd): + true, predicted = zip(*rdd.collect()) + mean_absolute_errors.append(mean(abs(true) - abs(predicted))) + + model_weights = [] + input_stream = self.ssc.queueStream(batches) + output_stream = self.ssc.queueStream(predict_batches) + t = time() + slr.trainOn(input_stream) + output_stream = slr.predictOnValues(output_stream) + output_stream.foreachRDD(func) + self.ssc.start() + self._ssc_wait(t, 10, 0.01) + self.assertTrue(mean_absolute_errors[1] - mean_absolute_errors[-1] > 2) + + if __name__ == "__main__": if not _have_scipy: print("NOTE: Skipping SciPy tests as it does not seem to be installed") From 5fa0863626aaf5a9a41756a0b1ec82bddccbf067 Mon Sep 17 00:00:00 2001 From: MechCoder Date: Tue, 30 Jun 2015 10:27:29 -0700 Subject: [PATCH 088/122] [SPARK-8679] [PYSPARK] [MLLIB] Default values in Pipeline API should be immutable It might be dangerous to have a mutable as value for default param. 
(http://stackoverflow.com/a/11416002/1170730) e.g def func(example, f={}): f[example] = 1 return f func(2) {2: 1} func(3) {2:1, 3:1} mengxr Author: MechCoder Closes #7058 from MechCoder/pipeline_api_playground and squashes the following commits: 40a5eb2 [MechCoder] copy 95f7ff2 [MechCoder] [SPARK-8679] [PySpark] [MLlib] Default values in Pipeline API should be immutable --- python/pyspark/ml/pipeline.py | 24 ++++++++++++++++++------ python/pyspark/ml/wrapper.py | 4 +++- 2 files changed, 21 insertions(+), 7 deletions(-) diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py index a563024b2cdcb..9889f56cac9e4 100644 --- a/python/pyspark/ml/pipeline.py +++ b/python/pyspark/ml/pipeline.py @@ -42,7 +42,7 @@ def _fit(self, dataset): """ raise NotImplementedError() - def fit(self, dataset, params={}): + def fit(self, dataset, params=None): """ Fits a model to the input dataset with optional parameters. @@ -54,6 +54,8 @@ def fit(self, dataset, params={}): list of models. :returns: fitted model(s) """ + if params is None: + params = dict() if isinstance(params, (list, tuple)): return [self.fit(dataset, paramMap) for paramMap in params] elif isinstance(params, dict): @@ -86,7 +88,7 @@ def _transform(self, dataset): """ raise NotImplementedError() - def transform(self, dataset, params={}): + def transform(self, dataset, params=None): """ Transforms the input dataset with optional parameters. @@ -96,6 +98,8 @@ def transform(self, dataset, params={}): params. :returns: transformed dataset """ + if params is None: + params = dict() if isinstance(params, dict): if params: return self.copy(params,)._transform(dataset) @@ -135,10 +139,12 @@ class Pipeline(Estimator): """ @keyword_only - def __init__(self, stages=[]): + def __init__(self, stages=None): """ __init__(self, stages=[]) """ + if stages is None: + stages = [] super(Pipeline, self).__init__() #: Param for pipeline stages. self.stages = Param(self, "stages", "pipeline stages") @@ -162,11 +168,13 @@ def getStages(self): return self._paramMap[self.stages] @keyword_only - def setParams(self, stages=[]): + def setParams(self, stages=None): """ setParams(self, stages=[]) Sets params for Pipeline. """ + if stages is None: + stages = [] kwargs = self.setParams._input_kwargs return self._set(**kwargs) @@ -195,7 +203,9 @@ def _fit(self, dataset): transformers.append(stage) return PipelineModel(transformers) - def copy(self, extra={}): + def copy(self, extra=None): + if extra is None: + extra = dict() that = Params.copy(self, extra) stages = [stage.copy(extra) for stage in that.getStages()] return that.setStages(stages) @@ -216,6 +226,8 @@ def _transform(self, dataset): dataset = t.transform(dataset) return dataset - def copy(self, extra={}): + def copy(self, extra=None): + if extra is None: + extra = dict() stages = [stage.copy(extra) for stage in self.stages] return PipelineModel(stages) diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py index 7b0893e2cdadc..253705bde913e 100644 --- a/python/pyspark/ml/wrapper.py +++ b/python/pyspark/ml/wrapper.py @@ -166,7 +166,7 @@ def __init__(self, java_model): self._java_obj = java_model self.uid = java_model.uid() - def copy(self, extra={}): + def copy(self, extra=None): """ Creates a copy of this instance with the same uid and some extra params. 
This implementation first calls Params.copy and @@ -175,6 +175,8 @@ def copy(self, extra={}): :param extra: Extra parameters to copy to the new instance :return: Copy of this instance """ + if extra is None: + extra = dict() that = super(JavaModel, self).copy(extra) that._java_obj = self._java_obj.copy(self._empty_java_param_map()) that._transfer_params_to_java() From fbb267ed6fe799a58f88c2fba2d41e954e5f1547 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 30 Jun 2015 10:48:49 -0700 Subject: [PATCH 089/122] [SPARK-8713] Make codegen thread safe Codegen takes three steps: 1. Take a list of expressions, convert them into Java source code and a list of expressions that don't not support codegen (fallback to interpret mode). 2. Compile the Java source into Java class (bytecode) 3. Using the Java class and the list of expression to build a Projection. Currently, we cache the whole three steps, the key is a list of expression, result is projection. Because some of expressions (which may not thread-safe, for example, Random) will be hold by the Projection, the projection maybe not thread safe. This PR change to only cache the second step, then we can build projection using codegen even some expressions are not thread-safe, because the cache will not hold any expression anymore. cc marmbrus rxin JoshRosen Author: Davies Liu Closes #7101 from davies/codegen_safe and squashes the following commits: 7dd41f1 [Davies Liu] Merge branch 'master' of github.com:apache/spark into codegen_safe 847bd08 [Davies Liu] don't use scala.refect 4ddaaed [Davies Liu] Merge branch 'master' of github.com:apache/spark into codegen_safe 1793cf1 [Davies Liu] make codegen thread safe --- .../sql/catalyst/expressions/Expression.scala | 14 ----------- .../sql/catalyst/expressions/ScalaUDF.scala | 3 --- .../expressions/codegen/CodeGenerator.scala | 25 ++++++++++--------- .../codegen/GenerateOrdering.scala | 9 +++---- .../codegen/GenerateProjection.scala | 7 +++--- .../expressions/namedExpressions.scala | 2 -- .../catalyst/expressions/nullFunctions.scala | 2 -- .../spark/sql/execution/SparkPlan.scala | 6 ++--- .../MonotonicallyIncreasingID.scala | 2 -- .../apache/spark/sql/sources/commands.scala | 2 +- .../apache/spark/sql/sources/interfaces.scala | 2 +- .../org/apache/spark/sql/hive/hiveUDFs.scala | 4 --- 12 files changed, 24 insertions(+), 54 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index aed48921bdeb5..b5063f32fa529 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -60,14 +60,6 @@ abstract class Expression extends TreeNode[Expression] { /** Returns the result of evaluating this expression on a given input Row */ def eval(input: InternalRow = null): Any - /** - * Return true if this expression is thread-safe, which means it could be used by multiple - * threads in the same time. - * - * An expression that is not thread-safe can not be cached and re-used, especially for codegen. - */ - def isThreadSafe: Boolean = true - /** * Returns an [[GeneratedExpressionCode]], which contains Java source code that * can be used to generate the result of evaluating the expression on an input row. 
@@ -76,9 +68,6 @@ abstract class Expression extends TreeNode[Expression] { * @return [[GeneratedExpressionCode]] */ def gen(ctx: CodeGenContext): GeneratedExpressionCode = { - if (!isThreadSafe) { - throw new Exception(s"$this is not thread-safe, can not be used in codegen") - } val isNull = ctx.freshName("isNull") val primitive = ctx.freshName("primitive") val ve = GeneratedExpressionCode("", isNull, primitive) @@ -178,8 +167,6 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express override def toString: String = s"($left $symbol $right)" - override def isThreadSafe: Boolean = left.isThreadSafe && right.isThreadSafe - /** * Short hand for generating binary evaluation code. * If either of the sub-expressions is null, the result of this computation @@ -237,7 +224,6 @@ abstract class UnaryExpression extends Expression with trees.UnaryNode[Expressio override def foldable: Boolean = child.foldable override def nullable: Boolean = child.nullable - override def isThreadSafe: Boolean = child.isThreadSafe /** * Called by unary expressions to generate a code block that returns null if its parent returns diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index dbb4381d54c4f..ebabb6f117851 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -956,7 +956,4 @@ case class ScalaUDF(function: AnyRef, dataType: DataType, children: Seq[Expressi // scalastyle:on private[this] val converter = CatalystTypeConverters.createToCatalystConverter(dataType) override def eval(input: InternalRow): Any = converter(f(input)) - - // TODO(davies): make ScalaUDF work with codegen - override def isThreadSafe: Boolean = false } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index bf6a6a124088e..a64027e48a00b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -235,11 +235,15 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin /** * Compile the Java source code into a Java class, using Janino. - * - * It will track the time used to compile */ protected def compile(code: String): GeneratedClass = { - val startTime = System.nanoTime() + cache.get(code) + } + + /** + * Compile the Java source code into a Java class, using Janino. 
+ */ + private[this] def doCompile(code: String): GeneratedClass = { val evaluator = new ClassBodyEvaluator() evaluator.setParentClassLoader(getClass.getClassLoader) evaluator.setDefaultImports(Array("org.apache.spark.sql.catalyst.InternalRow")) @@ -251,9 +255,6 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin logError(s"failed to compile:\n $code", e) throw e } - val endTime = System.nanoTime() - def timeMs: Double = (endTime - startTime).toDouble / 1000000 - logDebug(s"Code (${code.size} bytes) compiled in $timeMs ms") evaluator.getClazz().newInstance().asInstanceOf[GeneratedClass] } @@ -266,16 +267,16 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin * automatically, in order to constrain its memory footprint. Note that this cache does not use * weak keys/values and thus does not respond to memory pressure. */ - protected val cache = CacheBuilder.newBuilder() + private val cache = CacheBuilder.newBuilder() .maximumSize(100) .build( - new CacheLoader[InType, OutType]() { - override def load(in: InType): OutType = { + new CacheLoader[String, GeneratedClass]() { + override def load(code: String): GeneratedClass = { val startTime = System.nanoTime() - val result = create(in) + val result = doCompile(code) val endTime = System.nanoTime() def timeMs: Double = (endTime - startTime).toDouble / 1000000 - logInfo(s"Code generated expression $in in $timeMs ms") + logInfo(s"Code generated in $timeMs ms") result } }) @@ -285,7 +286,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin generate(bind(expressions, inputSchema)) /** Generates the requested evaluator given already bound expression(s). */ - def generate(expressions: InType): OutType = cache.get(canonicalize(expressions)) + def generate(expressions: InType): OutType = create(canonicalize(expressions)) /** * Create a new codegen context for expression evaluator, used to store those diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala index 7ed2c5addec9b..97cb16045ae4a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala @@ -38,7 +38,6 @@ class BaseOrdering extends Ordering[InternalRow] { */ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalRow]] with Logging { - import scala.reflect.runtime.universe._ protected def canonicalize(in: Seq[SortOrder]): Seq[SortOrder] = in.map(ExpressionCanonicalizer.execute(_).asInstanceOf[SortOrder]) @@ -47,8 +46,6 @@ object GenerateOrdering in.map(BindReferences.bindReference(_, inputSchema)) protected def create(ordering: Seq[SortOrder]): Ordering[InternalRow] = { - val a = newTermName("a") - val b = newTermName("b") val ctx = newCodeGenContext() val comparisons = ordering.zipWithIndex.map { case (order, i) => @@ -56,9 +53,9 @@ object GenerateOrdering val evalB = order.child.gen(ctx) val asc = order.direction == Ascending s""" - i = $a; + i = a; ${evalA.code} - i = $b; + i = b; ${evalB.code} if (${evalA.isNull} && ${evalB.isNull}) { // Nothing @@ -80,7 +77,7 @@ object GenerateOrdering return new SpecificOrdering(expr); } - class SpecificOrdering extends ${typeOf[BaseOrdering]} { + class SpecificOrdering extends ${classOf[BaseOrdering].getName} { private 
$exprType[] expressions = null; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala index 39d32b78cc14a..5be47175fa7f1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala @@ -32,7 +32,6 @@ abstract class BaseProject extends Projection {} * primitive values. */ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { - import scala.reflect.runtime.universe._ protected def canonicalize(in: Seq[Expression]): Seq[Expression] = in.map(ExpressionCanonicalizer.execute) @@ -157,7 +156,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { return new SpecificProjection(expr); } - class SpecificProjection extends ${typeOf[BaseProject]} { + class SpecificProjection extends ${classOf[BaseProject].getName} { private $exprType[] expressions = null; public SpecificProjection($exprType[] expr) { @@ -170,7 +169,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { } } - final class SpecificRow extends ${typeOf[MutableRow]} { + final class SpecificRow extends ${classOf[MutableRow].getName} { $columns @@ -224,7 +223,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { public InternalRow copy() { Object[] arr = new Object[${expressions.length}]; ${copyColumns} - return new ${typeOf[GenericInternalRow]}(arr); + return new ${classOf[GenericInternalRow].getName}(arr); } } """ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 6f56a9ec7beb5..81ebda3060c51 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -117,8 +117,6 @@ case class Alias(child: Expression, name: String)( override def eval(input: InternalRow): Any = child.eval(input) - override def isThreadSafe: Boolean = child.isThreadSafe - override def gen(ctx: CodeGenContext): GeneratedExpressionCode = child.gen(ctx) override def dataType: DataType = child.dataType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala index 5d5911403ece1..78be2824347d7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala @@ -51,8 +51,6 @@ case class Coalesce(children: Seq[Expression]) extends Expression { result } - override def isThreadSafe: Boolean = children.forall(_.isThreadSafe) - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { s""" boolean ${ev.isNull} = true; diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 47f56b2b7ebe6..7739a9f949c77 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -156,7 +156,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ expressions: Seq[Expression], inputSchema: Seq[Attribute]): Projection = { log.debug( s"Creating Projection: $expressions, inputSchema: $inputSchema, codegen:$codegenEnabled") - if (codegenEnabled && expressions.forall(_.isThreadSafe)) { + if (codegenEnabled) { GenerateProjection.generate(expressions, inputSchema) } else { new InterpretedProjection(expressions, inputSchema) @@ -168,7 +168,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ inputSchema: Seq[Attribute]): () => MutableProjection = { log.debug( s"Creating MutableProj: $expressions, inputSchema: $inputSchema, codegen:$codegenEnabled") - if(codegenEnabled && expressions.forall(_.isThreadSafe)) { + if(codegenEnabled) { GenerateMutableProjection.generate(expressions, inputSchema) } else { () => new InterpretedMutableProjection(expressions, inputSchema) @@ -178,7 +178,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ protected def newPredicate( expression: Expression, inputSchema: Seq[Attribute]): (InternalRow) => Boolean = { - if (codegenEnabled && expression.isThreadSafe) { + if (codegenEnabled) { GeneratePredicate.generate(expression, inputSchema) } else { InterpretedPredicate.create(expression, inputSchema) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala index 3b217348b7b7a..68914cf85cb50 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/expressions/MonotonicallyIncreasingID.scala @@ -48,6 +48,4 @@ private[sql] case class MonotonicallyIncreasingID() extends LeafExpression { count += 1 (TaskContext.get().partitionId().toLong << 33) + currentCount } - - override def isThreadSafe: Boolean = false } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala index 54c8eeb41a8ea..42b51caab5ce9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala @@ -270,7 +270,7 @@ private[sql] case class InsertIntoHadoopFsRelation( inputSchema: Seq[Attribute]): Projection = { log.debug( s"Creating Projection: $expressions, inputSchema: $inputSchema, codegen:$codegenEnabled") - if (codegenEnabled && expressions.forall(_.isThreadSafe)) { + if (codegenEnabled) { GenerateProjection.generate(expressions, inputSchema) } else { new InterpretedProjection(expressions, inputSchema) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index 7005c7079af91..0b875304f9b0e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -591,7 +591,7 @@ abstract class HadoopFsRelation private[sql](maybePartitionSpec: Option[Partitio rdd.map(_.asInstanceOf[InternalRow]) } converted.mapPartitions { rows => - val buildProjection = if (codegenEnabled && requiredOutput.forall(_.isThreadSafe)) { + val buildProjection = if (codegenEnabled) { 
GenerateMutableProjection.generate(requiredOutput, dataSchema.toAttributes) } else { () => new InterpretedMutableProjection(requiredOutput, dataSchema.toAttributes) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala index d7827d56ca8c5..4dea561ae5f60 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala @@ -120,8 +120,6 @@ private[hive] case class HiveSimpleUDF(funcWrapper: HiveFunctionWrapper, childre @transient protected lazy val cached: Array[AnyRef] = new Array[AnyRef](children.length) - override def isThreadSafe: Boolean = false - // TODO: Finish input output types. override def eval(input: InternalRow): Any = { unwrap( @@ -180,8 +178,6 @@ private[hive] case class HiveGenericUDF(funcWrapper: HiveFunctionWrapper, childr lazy val dataType: DataType = inspectorToDataType(returnInspector) - override def isThreadSafe: Boolean = false - override def eval(input: InternalRow): Any = { returnInspector // Make sure initialized. From 9213f73a8ea09ae343af825a6b576c212cf4a0c7 Mon Sep 17 00:00:00 2001 From: Tijo Thomas Date: Tue, 30 Jun 2015 10:50:45 -0700 Subject: [PATCH 090/122] [SPARK-8615] [DOCUMENTATION] Fixed Sample deprecated code Modified the deprecated jdbc api in the documentation. Author: Tijo Thomas Closes #7039 from tijoparacka/JIRA_8615 and squashes the following commits: 6e73b8a [Tijo Thomas] Reverted new lines 4042fcf [Tijo Thomas] updated to sql documentation a27949c [Tijo Thomas] Fixed Sample deprecated code --- docs/sql-programming-guide.md | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 2786e3d2cd6bf..88c96a9a095b3 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1773,9 +1773,9 @@ the Data Sources API. The following options are supported:
 {% highlight scala %}
-val jdbcDF = sqlContext.load("jdbc", Map(
-  "url" -> "jdbc:postgresql:dbserver",
-  "dbtable" -> "schema.tablename"))
+val jdbcDF = sqlContext.read.format("jdbc").options(
+  Map("url" -> "jdbc:postgresql:dbserver",
+  "dbtable" -> "schema.tablename")).load()
 {% endhighlight %}
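An equivalent form of the new reader API sets each property with its own `option` call; a small sketch for comparison (the connection URL and table name are the same placeholders used above):

```scala
// Same JDBC source, expressed with individual options on the DataFrameReader.
val jdbcDF = sqlContext.read
  .format("jdbc")
  .option("url", "jdbc:postgresql:dbserver")
  .option("dbtable", "schema.tablename")
  .load()
```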
@@ -1788,7 +1788,7 @@ Map options = new HashMap(); options.put("url", "jdbc:postgresql:dbserver"); options.put("dbtable", "schema.tablename"); -DataFrame jdbcDF = sqlContext.load("jdbc", options) +DataFrame jdbcDF = sqlContext.read().format("jdbc"). options(options).load(); {% endhighlight %} @@ -1798,7 +1798,7 @@ DataFrame jdbcDF = sqlContext.load("jdbc", options) {% highlight python %} -df = sqlContext.load(source="jdbc", url="jdbc:postgresql:dbserver", dbtable="schema.tablename") +df = sqlContext.read.format('jdbc').options(url = 'jdbc:postgresql:dbserver', dbtable='schema.tablename').load() {% endhighlight %} From ca7e460f7d6fb898dc29236a85520bbe954c8a13 Mon Sep 17 00:00:00 2001 From: nishkamravi2 Date: Tue, 30 Jun 2015 11:12:15 -0700 Subject: [PATCH 091/122] [SPARK-7988] [STREAMING] Round-robin scheduling of receivers by default Minimal PR for round-robin scheduling of receivers. Dense scheduling can be enabled by setting preferredLocation, so a new config parameter isn't really needed. Tested this on a cluster of 6 nodes and noticed 20-25% gain in throughput compared to random scheduling. tdas pwendell Author: nishkamravi2 Author: Nishkam Ravi Closes #6607 from nishkamravi2/master_nravi and squashes the following commits: 1918819 [Nishkam Ravi] Update ReceiverTrackerSuite.scala f747739 [Nishkam Ravi] Update ReceiverTrackerSuite.scala 6127e58 [Nishkam Ravi] Update ReceiverTracker and ReceiverTrackerSuite 9f1abc2 [nishkamravi2] Update ReceiverTrackerSuite.scala ae29152 [Nishkam Ravi] Update test suite with TD's suggestions 48a4a97 [nishkamravi2] Update ReceiverTracker.scala bc23907 [nishkamravi2] Update ReceiverTracker.scala 68e8540 [nishkamravi2] Update SchedulerSuite.scala 4604f28 [nishkamravi2] Update SchedulerSuite.scala 179b90f [nishkamravi2] Update ReceiverTracker.scala 242e677 [nishkamravi2] Update SchedulerSuite.scala 7f3e028 [Nishkam Ravi] Update ReceiverTracker.scala, add unit test cases in SchedulerSuite f8a3e05 [nishkamravi2] Update ReceiverTracker.scala 4cf97b6 [nishkamravi2] Update ReceiverTracker.scala 16e84ec [Nishkam Ravi] Update ReceiverTracker.scala 45e3a99 [Nishkam Ravi] Merge branch 'master_nravi' of https://github.com/nishkamravi2/spark into master_nravi 02dbdb8 [Nishkam Ravi] Update ReceiverTracker.scala 07b9dfa [nishkamravi2] Update ReceiverTracker.scala 6caeefe [nishkamravi2] Update ReceiverTracker.scala 7888257 [nishkamravi2] Update ReceiverTracker.scala 6e3515c [Nishkam Ravi] Minor changes 975b8d8 [Nishkam Ravi] Merge branch 'master_nravi' of https://github.com/nishkamravi2/spark into master_nravi 3cac21b [Nishkam Ravi] Generalize the scheduling algorithm b05ee2f [nishkamravi2] Update ReceiverTracker.scala bb5e09b [Nishkam Ravi] Add a new var in receiver to store location information for round-robin scheduling 41705de [nishkamravi2] Update ReceiverTracker.scala fff1b2e [Nishkam Ravi] Round-robin scheduling of streaming receivers --- .../streaming/scheduler/ReceiverTracker.scala | 64 ++++++++++--- .../scheduler/ReceiverTrackerSuite.scala | 90 +++++++++++++++++++ 2 files changed, 141 insertions(+), 13 deletions(-) create mode 100644 streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala index e6cdbec11e94c..644e581cd8279 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala +++ 
b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala @@ -17,8 +17,10 @@ package org.apache.spark.streaming.scheduler -import scala.collection.mutable.{HashMap, SynchronizedMap} +import scala.collection.mutable.{ArrayBuffer, HashMap, SynchronizedMap} import scala.language.existentials +import scala.math.max +import org.apache.spark.rdd._ import org.apache.spark.streaming.util.WriteAheadLogUtils import org.apache.spark.{Logging, SparkEnv, SparkException} @@ -272,6 +274,41 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false } } + /** + * Get the list of executors excluding driver + */ + private def getExecutors(ssc: StreamingContext): List[String] = { + val executors = ssc.sparkContext.getExecutorMemoryStatus.map(_._1.split(":")(0)).toList + val driver = ssc.sparkContext.getConf.get("spark.driver.host") + executors.diff(List(driver)) + } + + /** Set host location(s) for each receiver so as to distribute them over + * executors in a round-robin fashion taking into account preferredLocation if set + */ + private[streaming] def scheduleReceivers(receivers: Seq[Receiver[_]], + executors: List[String]): Array[ArrayBuffer[String]] = { + val locations = new Array[ArrayBuffer[String]](receivers.length) + var i = 0 + for (i <- 0 until receivers.length) { + locations(i) = new ArrayBuffer[String]() + if (receivers(i).preferredLocation.isDefined) { + locations(i) += receivers(i).preferredLocation.get + } + } + var count = 0 + for (i <- 0 until max(receivers.length, executors.length)) { + if (!receivers(i % receivers.length).preferredLocation.isDefined) { + locations(i % receivers.length) += executors(count) + count += 1 + if (count == executors.length) { + count = 0 + } + } + } + locations + } + /** * Get the receivers from the ReceiverInputDStreams, distributes them to the * worker nodes as a parallel collection, and runs them. @@ -283,18 +320,6 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false rcvr }) - // Right now, we only honor preferences if all receivers have them - val hasLocationPreferences = receivers.map(_.preferredLocation.isDefined).reduce(_ && _) - - // Create the parallel collection of receivers to distributed them on the worker nodes - val tempRDD = - if (hasLocationPreferences) { - val receiversWithPreferences = receivers.map(r => (r, Seq(r.preferredLocation.get))) - ssc.sc.makeRDD[Receiver[_]](receiversWithPreferences) - } else { - ssc.sc.makeRDD(receivers, receivers.size) - } - val checkpointDirOption = Option(ssc.checkpointDir) val serializableHadoopConf = new SerializableConfiguration(ssc.sparkContext.hadoopConfiguration) @@ -311,12 +336,25 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false supervisor.start() supervisor.awaitTermination() } + // Run the dummy Spark job to ensure that all slaves have registered. // This avoids all the receivers to be scheduled on the same node. 
if (!ssc.sparkContext.isLocal) { ssc.sparkContext.makeRDD(1 to 50, 50).map(x => (x, 1)).reduceByKey(_ + _, 20).collect() } + // Get the list of executors and schedule receivers + val executors = getExecutors(ssc) + val tempRDD = + if (!executors.isEmpty) { + val locations = scheduleReceivers(receivers, executors) + val roundRobinReceivers = (0 until receivers.length).map(i => + (receivers(i), locations(i))) + ssc.sc.makeRDD[Receiver[_]](roundRobinReceivers) + } else { + ssc.sc.makeRDD(receivers, receivers.size) + } + // Distribute the receivers and start them logInfo("Starting " + receivers.length + " receivers") running = true diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala new file mode 100644 index 0000000000000..a6e783861dbe6 --- /dev/null +++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/ReceiverTrackerSuite.scala @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.streaming.scheduler + +import org.apache.spark.streaming._ +import org.apache.spark.SparkConf +import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming.receiver._ +import org.apache.spark.util.Utils + +/** Testsuite for receiver scheduling */ +class ReceiverTrackerSuite extends TestSuiteBase { + val sparkConf = new SparkConf().setMaster("local[8]").setAppName("test") + val ssc = new StreamingContext(sparkConf, Milliseconds(100)) + val tracker = new ReceiverTracker(ssc) + val launcher = new tracker.ReceiverLauncher() + val executors: List[String] = List("0", "1", "2", "3") + + test("receiver scheduling - all or none have preferred location") { + + def parse(s: String): Array[Array[String]] = { + val outerSplit = s.split("\\|") + val loc = new Array[Array[String]](outerSplit.length) + var i = 0 + for (i <- 0 until outerSplit.length) { + loc(i) = outerSplit(i).split("\\,") + } + loc + } + + def testScheduler(numReceivers: Int, preferredLocation: Boolean, allocation: String) { + val receivers = + if (preferredLocation) { + Array.tabulate(numReceivers)(i => new DummyReceiver(host = + Some(((i + 1) % executors.length).toString))) + } else { + Array.tabulate(numReceivers)(_ => new DummyReceiver) + } + val locations = launcher.scheduleReceivers(receivers, executors) + val expectedLocations = parse(allocation) + assert(locations.deep === expectedLocations.deep) + } + + testScheduler(numReceivers = 5, preferredLocation = false, allocation = "0|1|2|3|0") + testScheduler(numReceivers = 3, preferredLocation = false, allocation = "0,3|1|2") + testScheduler(numReceivers = 4, preferredLocation = true, allocation = "1|2|3|0") + } + + test("receiver scheduling - some have preferred location") { + val numReceivers = 4; + val receivers: Seq[Receiver[_]] = Seq(new DummyReceiver(host = Some("1")), + new DummyReceiver, new DummyReceiver, new DummyReceiver) + val locations = launcher.scheduleReceivers(receivers, executors) + assert(locations(0)(0) === "1") + assert(locations(1)(0) === "0") + assert(locations(2)(0) === "1") + assert(locations(0).length === 1) + assert(locations(3).length === 1) + } +} + +/** + * Dummy receiver implementation + */ +private class DummyReceiver(host: Option[String] = None) + extends Receiver[Int](StorageLevel.MEMORY_ONLY) { + + def onStart() { + } + + def onStop() { + } + + override def preferredLocation: Option[String] = host +} From 57264400ac7d9f9c59c387c252a9ed8d93fed4fa Mon Sep 17 00:00:00 2001 From: zsxwing Date: Tue, 30 Jun 2015 11:14:38 -0700 Subject: [PATCH 092/122] [SPARK-8630] [STREAMING] Prevent from checkpointing QueueInputDStream This PR throws an exception in `QueueInputDStream.writeObject` so that it can fail the application when calling `StreamingContext.start` rather than failing it during recovering QueueInputDStream. 
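The guard is just the standard Java serialization hook; a minimal, standalone sketch of the pattern (the class name below is illustrative, not the actual Spark class):

```scala
import java.io.{NotSerializableException, ObjectOutputStream}

// Illustrative stand-in for a DStream that holds state which cannot be recovered
// from a checkpoint (e.g. arbitrary user-provided RDDs in a queue).
class NonCheckpointableStream extends Serializable {
  // Java serialization calls this hook when the object is written out, which is
  // exactly what happens while the streaming checkpoint is being serialized.
  private def writeObject(oos: ObjectOutputStream): Unit = {
    throw new NotSerializableException("this stream doesn't support checkpointing")
  }
}
```

With such a hook in place, serializing the DStream graph for a checkpoint fails up front with a clear message instead of producing a checkpoint that cannot be restored.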
Author: zsxwing Closes #7016 from zsxwing/queueStream-checkpoint and squashes the following commits: 89a3d73 [zsxwing] Fix JavaAPISuite.testQueueStream cc40fd7 [zsxwing] Prevent from checkpointing QueueInputDStream --- .../spark/streaming/StreamingContext.scala | 8 ++++++++ .../api/java/JavaStreamingContext.scala | 18 +++++++++++++++--- .../streaming/dstream/QueueInputDStream.scala | 15 ++++++++++----- .../apache/spark/streaming/JavaAPISuite.java | 8 ++++++++ .../streaming/StreamingContextSuite.scala | 15 +++++++++++++++ 5 files changed, 56 insertions(+), 8 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index 1708f309fc002..ec49d0f42d122 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -477,6 +477,10 @@ class StreamingContext private[streaming] ( /** * Create an input stream from a queue of RDDs. In each batch, * it will process either one or all of the RDDs returned by the queue. + * + * NOTE: Arbitrary RDDs can be added to `queueStream`, there is no way to recover data of + * those RDDs, so `queueStream` doesn't support checkpointing. + * * @param queue Queue of RDDs * @param oneAtATime Whether only one RDD should be consumed from the queue in every interval * @tparam T Type of objects in the RDD @@ -491,6 +495,10 @@ class StreamingContext private[streaming] ( /** * Create an input stream from a queue of RDDs. In each batch, * it will process either one or all of the RDDs returned by the queue. + * + * NOTE: Arbitrary RDDs can be added to `queueStream`, there is no way to recover data of + * those RDDs, so `queueStream` doesn't support checkpointing. + * * @param queue Queue of RDDs * @param oneAtATime Whether only one RDD should be consumed from the queue in every interval * @param defaultRDD Default RDD is returned by the DStream when the queue is empty. diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala index 989e3a729ebc2..40deb6d7ea79a 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala @@ -419,7 +419,11 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { * Create an input stream from an queue of RDDs. In each batch, * it will process either one or all of the RDDs returned by the queue. * - * NOTE: changes to the queue after the stream is created will not be recognized. + * NOTE: + * 1. Changes to the queue after the stream is created will not be recognized. + * 2. Arbitrary RDDs can be added to `queueStream`, there is no way to recover data of + * those RDDs, so `queueStream` doesn't support checkpointing. + * * @param queue Queue of RDDs * @tparam T Type of objects in the RDD */ @@ -435,7 +439,11 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { * Create an input stream from an queue of RDDs. In each batch, * it will process either one or all of the RDDs returned by the queue. * - * NOTE: changes to the queue after the stream is created will not be recognized. + * NOTE: + * 1. Changes to the queue after the stream is created will not be recognized. + * 2. 
Arbitrary RDDs can be added to `queueStream`, there is no way to recover data of + * those RDDs, so `queueStream` doesn't support checkpointing. + * * @param queue Queue of RDDs * @param oneAtATime Whether only one RDD should be consumed from the queue in every interval * @tparam T Type of objects in the RDD @@ -455,7 +463,11 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { * Create an input stream from an queue of RDDs. In each batch, * it will process either one or all of the RDDs returned by the queue. * - * NOTE: changes to the queue after the stream is created will not be recognized. + * NOTE: + * 1. Changes to the queue after the stream is created will not be recognized. + * 2. Arbitrary RDDs can be added to `queueStream`, there is no way to recover data of + * those RDDs, so `queueStream` doesn't support checkpointing. + * * @param queue Queue of RDDs * @param oneAtATime Whether only one RDD should be consumed from the queue in every interval * @param defaultRDD Default RDD is returned by the DStream when the queue is empty diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/QueueInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/QueueInputDStream.scala index ed7da6dc1315e..a2f5d82a79bd3 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/QueueInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/QueueInputDStream.scala @@ -17,13 +17,14 @@ package org.apache.spark.streaming.dstream -import org.apache.spark.rdd.RDD -import org.apache.spark.rdd.UnionRDD -import scala.collection.mutable.Queue -import scala.collection.mutable.ArrayBuffer -import org.apache.spark.streaming.{Time, StreamingContext} +import java.io.{NotSerializableException, ObjectOutputStream} + +import scala.collection.mutable.{ArrayBuffer, Queue} import scala.reflect.ClassTag +import org.apache.spark.rdd.{RDD, UnionRDD} +import org.apache.spark.streaming.{Time, StreamingContext} + private[streaming] class QueueInputDStream[T: ClassTag]( @transient ssc: StreamingContext, @@ -36,6 +37,10 @@ class QueueInputDStream[T: ClassTag]( override def stop() { } + private def writeObject(oos: ObjectOutputStream): Unit = { + throw new NotSerializableException("queueStream doesn't support checkpointing") + } + override def compute(validTime: Time): Option[RDD[T]] = { val buffer = new ArrayBuffer[RDD[T]]() if (oneAtATime && queue.size > 0) { diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java index 1077b1b2cb7e3..a34f23475804a 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java @@ -364,6 +364,14 @@ private void testReduceByWindow(boolean withInverse) { @SuppressWarnings("unchecked") @Test public void testQueueStream() { + ssc.stop(); + // Create a new JavaStreamingContext without checkpointing + SparkConf conf = new SparkConf() + .setMaster("local[2]") + .setAppName("test") + .set("spark.streaming.clock", "org.apache.spark.util.ManualClock"); + ssc = new JavaStreamingContext(conf, new Duration(1000)); + List> expected = Arrays.asList( Arrays.asList(1,2,3), Arrays.asList(4,5,6), diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala index 819dd2ccfe915..56b4ce5638a51 100644 --- 
a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala @@ -20,6 +20,8 @@ package org.apache.spark.streaming import java.io.{File, NotSerializableException} import java.util.concurrent.atomic.AtomicInteger +import scala.collection.mutable.Queue + import org.apache.commons.io.FileUtils import org.scalatest.concurrent.Eventually._ import org.scalatest.concurrent.Timeouts @@ -665,6 +667,19 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with Timeo transformed.foreachRDD { rdd => rdd.collect() } } } + test("queueStream doesn't support checkpointing") { + val checkpointDir = Utils.createTempDir() + ssc = new StreamingContext(master, appName, batchDuration) + val rdd = ssc.sparkContext.parallelize(1 to 10) + ssc.queueStream[Int](Queue(rdd)).print() + ssc.checkpoint(checkpointDir.getAbsolutePath) + val e = intercept[NotSerializableException] { + ssc.start() + } + // StreamingContext.validate changes the message, so use "contains" here + assert(e.getMessage.contains("queueStream doesn't support checkpointing")) + } + def addInputStream(s: StreamingContext): DStream[Int] = { val input = (1 to 100).map(i => 1 to i) val inputStream = new TestInputStream(s, input, 1) From d16a9443750eebb7a3d7688d4b98a2ac39cc0da7 Mon Sep 17 00:00:00 2001 From: huangzhaowei Date: Tue, 30 Jun 2015 11:46:22 -0700 Subject: [PATCH 093/122] [SPARK-8619] [STREAMING] Don't recover keytab and principal configuration within Streaming checkpoint [Client.scala](https://github.com/apache/spark/blob/master/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala#L786) will change these configurations, so this would cause the problem that the Streaming recover logic can't find the local keytab file(since configuration was changed) ```scala sparkConf.set("spark.yarn.keytab", keytabFileName) sparkConf.set("spark.yarn.principal", args.principal) ``` Problem described at [Jira](https://issues.apache.org/jira/browse/SPARK-8619) Author: huangzhaowei Closes #7008 from SaintBacchus/SPARK-8619 and squashes the following commits: d50dbdf [huangzhaowei] Delect one blank space 9b8e92c [huangzhaowei] Fix code style and add a short comment. 0d8f800 [huangzhaowei] Don't recover keytab and principal configuration within Streaming checkpoint. --- .../org/apache/spark/streaming/Checkpoint.scala | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala index d8dc4e4101664..5279331c9e122 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala @@ -44,11 +44,23 @@ class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time) val sparkConfPairs = ssc.conf.getAll def createSparkConf(): SparkConf = { + + // Reload properties for the checkpoint application since user wants to set a reload property + // or spark had changed its value and user wants to set it back. 
+ val propertiesToReload = List( + "spark.master", + "spark.yarn.keytab", + "spark.yarn.principal") + val newSparkConf = new SparkConf(loadDefaults = false).setAll(sparkConfPairs) .remove("spark.driver.host") .remove("spark.driver.port") - val newMasterOption = new SparkConf(loadDefaults = true).getOption("spark.master") - newMasterOption.foreach { newMaster => newSparkConf.setMaster(newMaster) } + val newReloadConf = new SparkConf(loadDefaults = true) + propertiesToReload.foreach { prop => + newReloadConf.getOption(prop).foreach { value => + newSparkConf.set(prop, value) + } + } newSparkConf } From 1e1f339976641af4cc87d4010db57c3b600f91af Mon Sep 17 00:00:00 2001 From: Christian Kadner Date: Tue, 30 Jun 2015 12:22:34 -0700 Subject: [PATCH 094/122] [SPARK-6785] [SQL] fix DateTimeUtils for dates before 1970 Hi Michael, this Pull-Request is a follow-up to [PR-6242](https://github.com/apache/spark/pull/6242). I removed the two obsolete test cases from the HiveQuerySuite and deleted the corresponding golden answer files. Thanks for your review! Author: Christian Kadner Closes #6983 from ckadner/SPARK-6785 and squashes the following commits: ab1e79b [Christian Kadner] Merge remote-tracking branch 'origin/SPARK-6785' into SPARK-6785 1fed877 [Christian Kadner] [SPARK-6785][SQL] failed Scala style test, remove spaces on empty line DateTimeUtils.scala:61 9d8021d [Christian Kadner] [SPARK-6785][SQL] merge recent changes in DateTimeUtils & MiscFunctionsSuite b97c3fb [Christian Kadner] [SPARK-6785][SQL] move test case for DateTimeUtils to DateTimeUtilsSuite a451184 [Christian Kadner] [SPARK-6785][SQL] fix DateTimeUtils.fromJavaDate(java.util.Date) for Dates before 1970 --- .../sql/catalyst/util/DateTimeUtils.scala | 8 ++-- .../catalyst/util/DateTimeUtilsSuite.scala | 40 ++++++++++++++++++- .../sql/ScalaReflectionRelationSuite.scala | 2 +- ...te cast-0-a7cd69b80c77a771a2c955db666be53d | 1 - ... 
test 2-0-dc1b267f1d79d49e6675afe4fd2a34a5 | 1 - .../sql/hive/execution/HiveQuerySuite.scala | 14 ------- .../sql/hive/execution/SQLQuerySuite.scala | 31 +++++++++++++- 7 files changed, 75 insertions(+), 22 deletions(-) delete mode 100644 sql/hive/src/test/resources/golden/Date cast-0-a7cd69b80c77a771a2c955db666be53d delete mode 100644 sql/hive/src/test/resources/golden/Date comparison test 2-0-dc1b267f1d79d49e6675afe4fd2a34a5 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index 640e67e2ecd76..4269ad5d56737 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -59,10 +59,12 @@ object DateTimeUtils { } } - // we should use the exact day as Int, for example, (year, month, day) -> day - def millisToDays(millisLocal: Long): Int = { - ((millisLocal + threadLocalLocalTimeZone.get().getOffset(millisLocal)) / MILLIS_PER_DAY).toInt + def millisToDays(millisUtc: Long): Int = { + // SPARK-6785: use Math.floor so negative number of days (dates before 1970) + // will correctly work as input for function toJavaDate(Int) + val millisLocal = millisUtc.toDouble + threadLocalLocalTimeZone.get().getOffset(millisUtc) + Math.floor(millisLocal / MILLIS_PER_DAY).toInt } // reverse of millisToDays diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala index 03eb64f097a37..1d4a60c81efc5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala @@ -17,7 +17,8 @@ package org.apache.spark.sql.catalyst.util -import java.sql.Timestamp +import java.sql.{Date, Timestamp} +import java.text.SimpleDateFormat import org.apache.spark.SparkFunSuite @@ -48,4 +49,41 @@ class DateTimeUtilsSuite extends SparkFunSuite { val t2 = DateTimeUtils.toJavaTimestamp(DateTimeUtils.fromJulianDay(d1, ns1)) assert(t.equals(t2)) } + + test("SPARK-6785: java date conversion before and after epoch") { + def checkFromToJavaDate(d1: Date): Unit = { + val d2 = DateTimeUtils.toJavaDate(DateTimeUtils.fromJavaDate(d1)) + assert(d2.toString === d1.toString) + } + + val df1 = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss") + val df2 = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss z") + + checkFromToJavaDate(new Date(100)) + + checkFromToJavaDate(Date.valueOf("1970-01-01")) + + checkFromToJavaDate(new Date(df1.parse("1970-01-01 00:00:00").getTime)) + checkFromToJavaDate(new Date(df2.parse("1970-01-01 00:00:00 UTC").getTime)) + + checkFromToJavaDate(new Date(df1.parse("1970-01-01 00:00:01").getTime)) + checkFromToJavaDate(new Date(df2.parse("1970-01-01 00:00:01 UTC").getTime)) + + checkFromToJavaDate(new Date(df1.parse("1969-12-31 23:59:59").getTime)) + checkFromToJavaDate(new Date(df2.parse("1969-12-31 23:59:59 UTC").getTime)) + + checkFromToJavaDate(Date.valueOf("1969-01-01")) + + checkFromToJavaDate(new Date(df1.parse("1969-01-01 00:00:00").getTime)) + checkFromToJavaDate(new Date(df2.parse("1969-01-01 00:00:00 UTC").getTime)) + + checkFromToJavaDate(new Date(df1.parse("1969-01-01 00:00:01").getTime)) + checkFromToJavaDate(new Date(df2.parse("1969-01-01 00:00:01 UTC").getTime)) + + checkFromToJavaDate(new Date(df1.parse("1989-11-09 
11:59:59").getTime)) + checkFromToJavaDate(new Date(df2.parse("1989-11-09 19:59:59 UTC").getTime)) + + checkFromToJavaDate(new Date(df1.parse("1776-07-04 10:30:00").getTime)) + checkFromToJavaDate(new Date(df2.parse("1776-07-04 18:30:00 UTC").getTime)) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala index 4cb5ba2f0d5eb..ab6d3dd96d271 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala @@ -78,7 +78,7 @@ class ScalaReflectionRelationSuite extends SparkFunSuite { test("query case class RDD") { val data = ReflectData("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true, - new java.math.BigDecimal(1), new Date(12345), new Timestamp(12345), Seq(1, 2, 3)) + new java.math.BigDecimal(1), Date.valueOf("1970-01-01"), new Timestamp(12345), Seq(1, 2, 3)) Seq(data).toDF().registerTempTable("reflectData") assert(ctx.sql("SELECT * FROM reflectData").collect().head === diff --git a/sql/hive/src/test/resources/golden/Date cast-0-a7cd69b80c77a771a2c955db666be53d b/sql/hive/src/test/resources/golden/Date cast-0-a7cd69b80c77a771a2c955db666be53d deleted file mode 100644 index 98da82fa89386..0000000000000 --- a/sql/hive/src/test/resources/golden/Date cast-0-a7cd69b80c77a771a2c955db666be53d +++ /dev/null @@ -1 +0,0 @@ -1970-01-01 1970-01-01 1969-12-31 16:00:00 1969-12-31 16:00:00 1970-01-01 00:00:00 diff --git a/sql/hive/src/test/resources/golden/Date comparison test 2-0-dc1b267f1d79d49e6675afe4fd2a34a5 b/sql/hive/src/test/resources/golden/Date comparison test 2-0-dc1b267f1d79d49e6675afe4fd2a34a5 deleted file mode 100644 index 27ba77ddaf615..0000000000000 --- a/sql/hive/src/test/resources/golden/Date comparison test 2-0-dc1b267f1d79d49e6675afe4fd2a34a5 +++ /dev/null @@ -1 +0,0 @@ -true diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 51dabc67fa7c1..4cdba03b27022 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -324,20 +324,6 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { | FROM src LIMIT 1 """.stripMargin) - createQueryTest("Date comparison test 2", - "SELECT CAST(CAST(0 AS timestamp) AS date) > CAST(0 AS timestamp) FROM src LIMIT 1") - - createQueryTest("Date cast", - """ - | SELECT - | CAST(CAST(0 AS timestamp) AS date), - | CAST(CAST(CAST(0 AS timestamp) AS date) AS string), - | CAST(0 AS timestamp), - | CAST(CAST(0 AS timestamp) AS string), - | CAST(CAST(CAST('1970-01-01 23:00:00' AS timestamp) AS date) AS timestamp) - | FROM src LIMIT 1 - """.stripMargin) - createQueryTest("Simple Average", "SELECT AVG(key) FROM src") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 9f7e58f890241..6d645393a6da1 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -17,10 +17,12 @@ package org.apache.spark.sql.hive.execution +import java.sql.{Date, Timestamp} + +import org.apache.spark.sql._ import 
org.apache.spark.sql.catalyst.DefaultParserDialect import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries import org.apache.spark.sql.catalyst.errors.DialectException -import org.apache.spark.sql._ import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.hive.test.TestHive.implicits._ @@ -962,4 +964,31 @@ class SQLQuerySuite extends QueryTest { case None => // OK } } + + test("SPARK-6785: HiveQuerySuite - Date comparison test 2") { + checkAnswer( + sql("SELECT CAST(CAST(0 AS timestamp) AS date) > CAST(0 AS timestamp) FROM src LIMIT 1"), + Row(false)) + } + + test("SPARK-6785: HiveQuerySuite - Date cast") { + // new Date(0) == 1970-01-01 00:00:00.0 GMT == 1969-12-31 16:00:00.0 PST + checkAnswer( + sql( + """ + | SELECT + | CAST(CAST(0 AS timestamp) AS date), + | CAST(CAST(CAST(0 AS timestamp) AS date) AS string), + | CAST(0 AS timestamp), + | CAST(CAST(0 AS timestamp) AS string), + | CAST(CAST(CAST('1970-01-01 23:00:00' AS timestamp) AS date) AS timestamp) + | FROM src LIMIT 1 + """.stripMargin), + Row( + Date.valueOf("1969-12-31"), + String.valueOf("1969-12-31"), + Timestamp.valueOf("1969-12-31 16:00:00"), + String.valueOf("1969-12-31 16:00:00"), + Timestamp.valueOf("1970-01-01 00:00:00"))) + } } From c1befd780c3defc843baa75097de7ec427d3f8ca Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Tue, 30 Jun 2015 12:23:48 -0700 Subject: [PATCH 095/122] [SPARK-8664] [ML] Add PCA transformer Add PCA transformer for ML pipeline Author: Yanbo Liang Closes #7065 from yanboliang/spark-8664 and squashes the following commits: 4afae45 [Yanbo Liang] address comments e9effd7 [Yanbo Liang] Add PCA transformer --- .../org/apache/spark/ml/feature/PCA.scala | 130 ++++++++++++++++++ .../org/apache/spark/mllib/feature/PCA.scala | 2 +- .../apache/spark/ml/feature/PCASuite.scala | 64 +++++++++ 3 files changed, 195 insertions(+), 1 deletion(-) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala new file mode 100644 index 0000000000000..2d3bb680cf309 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.feature + +import org.apache.spark.annotation.Experimental +import org.apache.spark.ml._ +import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.util.Identifiable +import org.apache.spark.mllib.feature +import org.apache.spark.mllib.linalg.{Vector, VectorUDT} +import org.apache.spark.sql._ +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.{StructField, StructType} + +/** + * Params for [[PCA]] and [[PCAModel]]. + */ +private[feature] trait PCAParams extends Params with HasInputCol with HasOutputCol { + + /** + * The number of principal components. + * @group param + */ + final val k: IntParam = new IntParam(this, "k", "the number of principal components") + + /** @group getParam */ + def getK: Int = $(k) + +} + +/** + * :: Experimental :: + * PCA trains a model to project vectors to a low-dimensional space using PCA. + */ +@Experimental +class PCA (override val uid: String) extends Estimator[PCAModel] with PCAParams { + + def this() = this(Identifiable.randomUID("pca")) + + /** @group setParam */ + def setInputCol(value: String): this.type = set(inputCol, value) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + + /** @group setParam */ + def setK(value: Int): this.type = set(k, value) + + /** + * Computes a [[PCAModel]] that contains the principal components of the input vectors. + */ + override def fit(dataset: DataFrame): PCAModel = { + transformSchema(dataset.schema, logging = true) + val input = dataset.select($(inputCol)).map { case Row(v: Vector) => v} + val pca = new feature.PCA(k = $(k)) + val pcaModel = pca.fit(input) + copyValues(new PCAModel(uid, pcaModel).setParent(this)) + } + + override def transformSchema(schema: StructType): StructType = { + val inputType = schema($(inputCol)).dataType + require(inputType.isInstanceOf[VectorUDT], + s"Input column ${$(inputCol)} must be a vector column") + require(!schema.fieldNames.contains($(outputCol)), + s"Output column ${$(outputCol)} already exists.") + val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false) + StructType(outputFields) + } + + override def copy(extra: ParamMap): PCA = defaultCopy(extra) +} + +/** + * :: Experimental :: + * Model fitted by [[PCA]]. + */ +@Experimental +class PCAModel private[ml] ( + override val uid: String, + pcaModel: feature.PCAModel) + extends Model[PCAModel] with PCAParams { + + /** @group setParam */ + def setInputCol(value: String): this.type = set(inputCol, value) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + + /** + * Transform a vector by computed Principal Components. + * NOTE: Vectors to be transformed must be the same length + * as the source vectors given to [[PCA.fit()]]. 
+ */ + override def transform(dataset: DataFrame): DataFrame = { + transformSchema(dataset.schema, logging = true) + val pcaOp = udf { pcaModel.transform _ } + dataset.withColumn($(outputCol), pcaOp(col($(inputCol)))) + } + + override def transformSchema(schema: StructType): StructType = { + val inputType = schema($(inputCol)).dataType + require(inputType.isInstanceOf[VectorUDT], + s"Input column ${$(inputCol)} must be a vector column") + require(!schema.fieldNames.contains($(outputCol)), + s"Output column ${$(outputCol)} already exists.") + val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false) + StructType(outputFields) + } + + override def copy(extra: ParamMap): PCAModel = { + val copied = new PCAModel(uid, pcaModel) + copyValues(copied, extra) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala index 4e01e402b4283..2a66263d8b7d6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/PCA.scala @@ -68,7 +68,7 @@ class PCA(val k: Int) { * @param k number of principal components. * @param pc a principal components Matrix. Each column is one principal component. */ -class PCAModel private[mllib] (val k: Int, val pc: DenseMatrix) extends VectorTransformer { +class PCAModel private[spark] (val k: Int, val pc: DenseMatrix) extends VectorTransformer { /** * Transform a vector by computed Principal Components. * diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala new file mode 100644 index 0000000000000..d0ae36b28c7a9 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.feature + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.mllib.linalg.distributed.RowMatrix +import org.apache.spark.mllib.linalg.{Vector, Vectors, DenseMatrix, Matrices} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.mllib.feature.{PCAModel => OldPCAModel} +import org.apache.spark.sql.Row + +class PCASuite extends SparkFunSuite with MLlibTestSparkContext { + + test("params") { + ParamsSuite.checkParams(new PCA) + val mat = Matrices.dense(2, 2, Array(0.0, 1.0, 2.0, 3.0)).asInstanceOf[DenseMatrix] + val model = new PCAModel("pca", new OldPCAModel(2, mat)) + ParamsSuite.checkParams(model) + } + + test("pca") { + val data = Array( + Vectors.sparse(5, Seq((1, 1.0), (3, 7.0))), + Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0), + Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0) + ) + + val dataRDD = sc.parallelize(data, 2) + + val mat = new RowMatrix(dataRDD) + val pc = mat.computePrincipalComponents(3) + val expected = mat.multiply(pc).rows + + val df = sqlContext.createDataFrame(dataRDD.zip(expected)).toDF("features", "expected") + + val pca = new PCA() + .setInputCol("features") + .setOutputCol("pca_features") + .setK(3) + .fit(df) + + pca.transform(df).select("pca_features", "expected").collect().foreach { + case Row(x: Vector, y: Vector) => + assert(x ~== y absTol 1e-5, "Transformed vector is different with expected vector.") + } + } +} From b8e5bb6fc1553256e950fdad9cb5acc6b296816e Mon Sep 17 00:00:00 2001 From: Vinod K C Date: Tue, 30 Jun 2015 12:24:47 -0700 Subject: [PATCH 096/122] [SPARK-8628] [SQL] Race condition in AbstractSparkSQLParser.parse Made lexical iniatialization as lazy val Author: Vinod K C Closes #7015 from vinodkc/handle_lexical_initialize_schronization and squashes the following commits: b6d1c74 [Vinod K C] Avoided repeated lexical initialization 5863cf7 [Vinod K C] Removed space e27c66c [Vinod K C] Avoid reinitialization of lexical in parse method ef4f60f [Vinod K C] Reverted import order e9fc49a [Vinod K C] handle synchronization in SqlLexical.initialize --- .../apache/spark/sql/catalyst/AbstractSparkSQLParser.scala | 6 ++++-- .../scala/org/apache/spark/sql/catalyst/SqlParser.scala | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala index ef7b3ad9432cf..d494ae7b71d16 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst import scala.language.implicitConversions import scala.util.parsing.combinator.lexical.StdLexical import scala.util.parsing.combinator.syntactical.StandardTokenParsers -import scala.util.parsing.combinator.{PackratParsers, RegexParsers} +import scala.util.parsing.combinator.PackratParsers import scala.util.parsing.input.CharArrayReader.EofCh import org.apache.spark.sql.catalyst.plans.logical._ @@ -30,12 +30,14 @@ private[sql] abstract class AbstractSparkSQLParser def parse(input: String): LogicalPlan = { // Initialize the Keywords. 
- lexical.initialize(reservedWords) + initLexical phrase(start)(new lexical.Scanner(input)) match { case Success(plan, _) => plan case failureOrError => sys.error(failureOrError.toString) } } + /* One time initialization of lexical.This avoid reinitialization of lexical in parse method */ + protected lazy val initLexical: Unit = lexical.initialize(reservedWords) protected case class Keyword(str: String) { def normalize: String = lexical.normalizeKeyword(str) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index 79f526e823cd4..8d02fbf4f92c4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -40,7 +40,7 @@ class SqlParser extends AbstractSparkSQLParser with DataTypeParser { def parseExpression(input: String): Expression = { // Initialize the Keywords. - lexical.initialize(reservedWords) + initLexical phrase(projection)(new lexical.Scanner(input)) match { case Success(plan, _) => plan case failureOrError => sys.error(failureOrError.toString) From 74cc16dbc35e35fd5cd5542239dcb6e5e7f92d18 Mon Sep 17 00:00:00 2001 From: Feynman Liang Date: Tue, 30 Jun 2015 12:31:33 -0700 Subject: [PATCH 097/122] [SPARK-8471] [ML] Discrete Cosine Transform Feature Transformer Implementation and tests for Discrete Cosine Transformer. Author: Feynman Liang Closes #6894 from feynmanliang/dct-features and squashes the following commits: 433dbc7 [Feynman Liang] Test refactoring 91e9636 [Feynman Liang] Style guide and test helper refactor b5ac19c [Feynman Liang] Use Vector types, add Java test 530983a [Feynman Liang] Tests for other numeric datatypes 195d7aa [Feynman Liang] Implement support for arbitrary numeric types 95d4939 [Feynman Liang] Working DCT for 1D Doubles --- .../feature/DiscreteCosineTransformer.scala | 72 +++++++++++++++++ .../JavaDiscreteCosineTransformerSuite.java | 78 +++++++++++++++++++ .../DiscreteCosineTransformerSuite.scala | 73 +++++++++++++++++ 3 files changed, 223 insertions(+) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/feature/DiscreteCosineTransformer.scala create mode 100644 mllib/src/test/java/org/apache/spark/ml/feature/JavaDiscreteCosineTransformerSuite.java create mode 100644 mllib/src/test/scala/org/apache/spark/ml/feature/DiscreteCosineTransformerSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/DiscreteCosineTransformer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/DiscreteCosineTransformer.scala new file mode 100644 index 0000000000000..a2f4d59f81c44 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/DiscreteCosineTransformer.scala @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import edu.emory.mathcs.jtransforms.dct._ + +import org.apache.spark.annotation.Experimental +import org.apache.spark.ml.UnaryTransformer +import org.apache.spark.ml.param.BooleanParam +import org.apache.spark.ml.util.Identifiable +import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors} +import org.apache.spark.sql.types.DataType + +/** + * :: Experimental :: + * A feature transformer that takes the 1D discrete cosine transform of a real vector. No zero + * padding is performed on the input vector. + * It returns a real vector of the same length representing the DCT. The return vector is scaled + * such that the transform matrix is unitary (aka scaled DCT-II). + * + * More information on [[https://en.wikipedia.org/wiki/Discrete_cosine_transform#DCT-II Wikipedia]]. + */ +@Experimental +class DiscreteCosineTransformer(override val uid: String) + extends UnaryTransformer[Vector, Vector, DiscreteCosineTransformer] { + + def this() = this(Identifiable.randomUID("dct")) + + /** + * Indicates whether to perform the inverse DCT (true) or forward DCT (false). + * Default: false + * @group param + */ + def inverse: BooleanParam = new BooleanParam( + this, "inverse", "Set transformer to perform inverse DCT") + + /** @group setParam */ + def setInverse(value: Boolean): this.type = set(inverse, value) + + /** @group getParam */ + def getInverse: Boolean = $(inverse) + + setDefault(inverse -> false) + + override protected def createTransformFunc: Vector => Vector = { vec => + val result = vec.toArray + val jTransformer = new DoubleDCT_1D(result.length) + if ($(inverse)) jTransformer.inverse(result, true) else jTransformer.forward(result, true) + Vectors.dense(result) + } + + override protected def validateInputType(inputType: DataType): Unit = { + require(inputType.isInstanceOf[VectorUDT], s"Input type must be VectorUDT but got $inputType.") + } + + override protected def outputDataType: DataType = new VectorUDT +} diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaDiscreteCosineTransformerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaDiscreteCosineTransformerSuite.java new file mode 100644 index 0000000000000..28bc5f65e0532 --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaDiscreteCosineTransformerSuite.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.feature; + +import com.google.common.collect.Lists; +import edu.emory.mathcs.jtransforms.dct.DoubleDCT_1D; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.VectorUDT; +import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.types.Metadata; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; + +public class JavaDiscreteCosineTransformerSuite { + private transient JavaSparkContext jsc; + private transient SQLContext jsql; + + @Before + public void setUp() { + jsc = new JavaSparkContext("local", "JavaDiscreteCosineTransformerSuite"); + jsql = new SQLContext(jsc); + } + + @After + public void tearDown() { + jsc.stop(); + jsc = null; + } + + @Test + public void javaCompatibilityTest() { + double[] input = new double[] {1D, 2D, 3D, 4D}; + JavaRDD data = jsc.parallelize(Lists.newArrayList( + RowFactory.create(Vectors.dense(input)) + )); + DataFrame dataset = jsql.createDataFrame(data, new StructType(new StructField[]{ + new StructField("vec", (new VectorUDT()), false, Metadata.empty()) + })); + + double[] expectedResult = input.clone(); + (new DoubleDCT_1D(input.length)).forward(expectedResult, true); + + DiscreteCosineTransformer DCT = new DiscreteCosineTransformer() + .setInputCol("vec") + .setOutputCol("resultVec"); + + Row[] result = DCT.transform(dataset).select("resultVec").collect(); + Vector resultVec = result[0].getAs("resultVec"); + + Assert.assertArrayEquals(expectedResult, resultVec.toArray(), 1e-6); + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/DiscreteCosineTransformerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/DiscreteCosineTransformerSuite.scala new file mode 100644 index 0000000000000..ed0fc11f78f69 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/DiscreteCosineTransformerSuite.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.feature + +import scala.beans.BeanInfo + +import edu.emory.mathcs.jtransforms.dct.DoubleDCT_1D + +import org.apache.spark.SparkFunSuite +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.{DataFrame, Row} + +@BeanInfo +case class DCTTestData(vec: Vector, wantedVec: Vector) + +class DiscreteCosineTransformerSuite extends SparkFunSuite with MLlibTestSparkContext { + + test("forward transform of discrete cosine matches jTransforms result") { + val data = Vectors.dense((0 until 128).map(_ => 2D * math.random - 1D).toArray) + val inverse = false + + testDCT(data, inverse) + } + + test("inverse transform of discrete cosine matches jTransforms result") { + val data = Vectors.dense((0 until 128).map(_ => 2D * math.random - 1D).toArray) + val inverse = true + + testDCT(data, inverse) + } + + private def testDCT(data: Vector, inverse: Boolean): Unit = { + val expectedResultBuffer = data.toArray.clone() + if (inverse) { + (new DoubleDCT_1D(data.size)).inverse(expectedResultBuffer, true) + } else { + (new DoubleDCT_1D(data.size)).forward(expectedResultBuffer, true) + } + val expectedResult = Vectors.dense(expectedResultBuffer) + + val dataset = sqlContext.createDataFrame(Seq( + DCTTestData(data, expectedResult) + )) + + val transformer = new DiscreteCosineTransformer() + .setInputCol("vec") + .setOutputCol("resultVec") + .setInverse(inverse) + + transformer.transform(dataset) + .select("resultVec", "wantedVec") + .collect() + .foreach { case Row(resultVec: Vector, wantedVec: Vector) => + assert(Vectors.sqdist(resultVec, wantedVec) < 1e-6) + } + } +} From 61d7b533dd50bfac2162b4edcea94724bbd8fcb1 Mon Sep 17 00:00:00 2001 From: Yuhao Yang Date: Tue, 30 Jun 2015 12:44:43 -0700 Subject: [PATCH 098/122] [SPARK-7514] [MLLIB] Add MinMaxScaler to feature transformation jira: https://issues.apache.org/jira/browse/SPARK-7514 Add a popular scaling method to feature component, which is commonly known as min-max normalization or Rescaling. Core function is, Normalized(x) = (x - min) / (max - min) * scale + newBase where `newBase` and `scale` are parameters (type Double) of the `VectorTransformer`. `newBase` is the new minimum number for the features, and `scale` controls the ranges after transformation. This is a little complicated than the basic MinMax normalization, yet it provides flexibility so that users can control the range more specifically. like [0.1, 0.9] in some NN application. For case that `max == min`, 0.5 is used as the raw value. 
(0.5 * scale + newBase) I'll add UT once the design got settled ( and this is not considered as too naive) reference: http://en.wikipedia.org/wiki/Feature_scaling http://stn.spotfire.com/spotfire_client_help/index.htm#norm/norm_scale_between_0_and_1.htm Author: Yuhao Yang Closes #6039 from hhbyyh/minMaxNorm and squashes the following commits: f942e9f [Yuhao Yang] add todo for metadata 8b37bbc [Yuhao Yang] Merge remote-tracking branch 'upstream/master' into minMaxNorm 4894dbc [Yuhao Yang] add copy fa2989f [Yuhao Yang] Merge remote-tracking branch 'upstream/master' into minMaxNorm 29db415 [Yuhao Yang] add clue and minor adjustment 5b8f7cc [Yuhao Yang] style fix 9b133d0 [Yuhao Yang] Merge remote-tracking branch 'upstream/master' into minMaxNorm 22f20f2 [Yuhao Yang] style change and bug fix 747c9bb [Yuhao Yang] add ut and remove mllib version a5ba0aa [Yuhao Yang] Merge remote-tracking branch 'upstream/master' into minMaxNorm 585cc07 [Yuhao Yang] Merge remote-tracking branch 'upstream/master' into minMaxNorm 1c6dcb1 [Yuhao Yang] minor change 0f1bc80 [Yuhao Yang] add MinMaxScaler to ml 8e7436e [Yuhao Yang] Merge remote-tracking branch 'upstream/master' into minMaxNorm 3663165 [Yuhao Yang] Merge remote-tracking branch 'upstream/master' into minMaxNorm 1247c27 [Yuhao Yang] some comments improvement d285a19 [Yuhao Yang] initial checkin for minMaxNorm --- .../spark/ml/feature/MinMaxScaler.scala | 170 ++++++++++++++++++ .../spark/ml/feature/MinMaxScalerSuite.scala | 68 +++++++ 2 files changed, 238 insertions(+) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala new file mode 100644 index 0000000000000..b30adf3df48d2 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala @@ -0,0 +1,170 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.feature + +import org.apache.spark.annotation.Experimental +import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} +import org.apache.spark.ml.param.{ParamMap, DoubleParam, Params} +import org.apache.spark.ml.util.Identifiable +import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors} +import org.apache.spark.mllib.stat.Statistics +import org.apache.spark.sql._ +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types.{StructField, StructType} + +/** + * Params for [[MinMaxScaler]] and [[MinMaxScalerModel]]. 
+ */ +private[feature] trait MinMaxScalerParams extends Params with HasInputCol with HasOutputCol { + + /** + * lower bound after transformation, shared by all features + * Default: 0.0 + * @group param + */ + val min: DoubleParam = new DoubleParam(this, "min", + "lower bound of the output feature range") + + /** + * upper bound after transformation, shared by all features + * Default: 1.0 + * @group param + */ + val max: DoubleParam = new DoubleParam(this, "max", + "upper bound of the output feature range") + + /** Validates and transforms the input schema. */ + protected def validateAndTransformSchema(schema: StructType): StructType = { + val inputType = schema($(inputCol)).dataType + require(inputType.isInstanceOf[VectorUDT], + s"Input column ${$(inputCol)} must be a vector column") + require(!schema.fieldNames.contains($(outputCol)), + s"Output column ${$(outputCol)} already exists.") + val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false) + StructType(outputFields) + } + + override def validateParams(): Unit = { + require($(min) < $(max), s"The specified min(${$(min)}) is larger or equal to max(${$(max)})") + } +} + +/** + * :: Experimental :: + * Rescale each feature individually to a common range [min, max] linearly using column summary + * statistics, which is also known as min-max normalization or Rescaling. The rescaled value for + * feature E is calculated as, + * + * Rescaled(e_i) = \frac{e_i - E_{min}}{E_{max} - E_{min}} * (max - min) + min + * + * For the case E_{max} == E_{min}, Rescaled(e_i) = 0.5 * (max + min) + * Note that since zero values will probably be transformed to non-zero values, output of the + * transformer will be DenseVector even for sparse input. + */ +@Experimental +class MinMaxScaler(override val uid: String) + extends Estimator[MinMaxScalerModel] with MinMaxScalerParams { + + def this() = this(Identifiable.randomUID("minMaxScal")) + + setDefault(min -> 0.0, max -> 1.0) + + /** @group setParam */ + def setInputCol(value: String): this.type = set(inputCol, value) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + + /** @group setParam */ + def setMin(value: Double): this.type = set(min, value) + + /** @group setParam */ + def setMax(value: Double): this.type = set(max, value) + + override def fit(dataset: DataFrame): MinMaxScalerModel = { + transformSchema(dataset.schema, logging = true) + val input = dataset.select($(inputCol)).map { case Row(v: Vector) => v } + val summary = Statistics.colStats(input) + copyValues(new MinMaxScalerModel(uid, summary.min, summary.max).setParent(this)) + } + + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema) + } + + override def copy(extra: ParamMap): MinMaxScaler = defaultCopy(extra) +} + +/** + * :: Experimental :: + * Model fitted by [[MinMaxScaler]]. + * + * TODO: The transformer does not yet set the metadata in the output column (SPARK-8529). 
+ */ +@Experimental +class MinMaxScalerModel private[ml] ( + override val uid: String, + val originalMin: Vector, + val originalMax: Vector) + extends Model[MinMaxScalerModel] with MinMaxScalerParams { + + /** @group setParam */ + def setInputCol(value: String): this.type = set(inputCol, value) + + /** @group setParam */ + def setOutputCol(value: String): this.type = set(outputCol, value) + + /** @group setParam */ + def setMin(value: Double): this.type = set(min, value) + + /** @group setParam */ + def setMax(value: Double): this.type = set(max, value) + + + override def transform(dataset: DataFrame): DataFrame = { + val originalRange = (originalMax.toBreeze - originalMin.toBreeze).toArray + val minArray = originalMin.toArray + + val reScale = udf { (vector: Vector) => + val scale = $(max) - $(min) + + // 0 in sparse vector will probably be rescaled to non-zero + val values = vector.toArray + val size = values.size + var i = 0 + while (i < size) { + val raw = if (originalRange(i) != 0) (values(i) - minArray(i)) / originalRange(i) else 0.5 + values(i) = raw * scale + $(min) + i += 1 + } + Vectors.dense(values) + } + + dataset.withColumn($(outputCol), reScale(col($(inputCol)))) + } + + override def transformSchema(schema: StructType): StructType = { + validateAndTransformSchema(schema) + } + + override def copy(extra: ParamMap): MinMaxScalerModel = { + val copied = new MinMaxScalerModel(uid, originalMin, originalMax) + copyValues(copied, extra) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala new file mode 100644 index 0000000000000..c452054bec92f --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.ml.feature + +import org.apache.spark.SparkFunSuite +import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.{Row, SQLContext} + +class MinMaxScalerSuite extends SparkFunSuite with MLlibTestSparkContext { + + test("MinMaxScaler fit basic case") { + val sqlContext = new SQLContext(sc) + + val data = Array( + Vectors.dense(1, 0, Long.MinValue), + Vectors.dense(2, 0, 0), + Vectors.sparse(3, Array(0, 2), Array(3, Long.MaxValue)), + Vectors.sparse(3, Array(0), Array(1.5))) + + val expected: Array[Vector] = Array( + Vectors.dense(-5, 0, -5), + Vectors.dense(0, 0, 0), + Vectors.sparse(3, Array(0, 2), Array(5, 5)), + Vectors.sparse(3, Array(0), Array(-2.5))) + + val df = sqlContext.createDataFrame(data.zip(expected)).toDF("features", "expected") + val scaler = new MinMaxScaler() + .setInputCol("features") + .setOutputCol("scaled") + .setMin(-5) + .setMax(5) + + val model = scaler.fit(df) + model.transform(df).select("expected", "scaled").collect() + .foreach { case Row(vector1: Vector, vector2: Vector) => + assert(vector1.equals(vector2), "Transformed vector is different with expected.") + } + } + + test("MinMaxScaler arguments max must be larger than min") { + withClue("arguments max must be larger than min") { + intercept[IllegalArgumentException] { + val scaler = new MinMaxScaler().setMin(10).setMax(0) + scaler.validateParams() + } + intercept[IllegalArgumentException] { + val scaler = new MinMaxScaler().setMin(0).setMax(0) + scaler.validateParams() + } + } + } +} From 79f0b371a36560a009c1b0943c928adc5a1bdd8f Mon Sep 17 00:00:00 2001 From: xutingjun Date: Tue, 30 Jun 2015 13:56:59 -0700 Subject: [PATCH 099/122] [SPARK-8560] [UI] The Executors page will have negative if having resubmitted tasks when the ```taskEnd.reason``` is ```Resubmitted```, it shouldn't do statistics. Because this tasks has a ```SUCCESS``` taskEnd before. Author: xutingjun Closes #6950 from XuTingjun/pageError and squashes the following commits: af35dc3 [xutingjun] When taskEnd is Resubmitted, don't do statistics --- .../org/apache/spark/ui/exec/ExecutorsTab.scala | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala index 39583af14390d..a88fc4c37d3c9 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala @@ -19,7 +19,7 @@ package org.apache.spark.ui.exec import scala.collection.mutable.HashMap -import org.apache.spark.{ExceptionFailure, SparkContext} +import org.apache.spark.{Resubmitted, ExceptionFailure, SparkContext} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.scheduler._ import org.apache.spark.storage.{StorageStatus, StorageStatusListener} @@ -92,15 +92,22 @@ class ExecutorsListener(storageStatusListener: StorageStatusListener) extends Sp val info = taskEnd.taskInfo if (info != null) { val eid = info.executorId - executorToTasksActive(eid) = executorToTasksActive.getOrElse(eid, 1) - 1 - executorToDuration(eid) = executorToDuration.getOrElse(eid, 0L) + info.duration taskEnd.reason match { + case Resubmitted => + // Note: For resubmitted tasks, we continue to use the metrics that belong to the + // first attempt of this task. This may not be 100% accurate because the first attempt + // could have failed half-way through. 
The correct fix would be to keep track of the + // metrics added by each attempt, but this is much more complicated. + return case e: ExceptionFailure => executorToTasksFailed(eid) = executorToTasksFailed.getOrElse(eid, 0) + 1 case _ => executorToTasksComplete(eid) = executorToTasksComplete.getOrElse(eid, 0) + 1 } + executorToTasksActive(eid) = executorToTasksActive.getOrElse(eid, 1) - 1 + executorToDuration(eid) = executorToDuration.getOrElse(eid, 0L) + info.duration + // Update shuffle read/write val metrics = taskEnd.taskMetrics if (metrics != null) { From 7dda0844e1eb6df7455af68592751806b3b92251 Mon Sep 17 00:00:00 2001 From: Joshi Date: Tue, 30 Jun 2015 14:00:35 -0700 Subject: [PATCH 100/122] [SPARK-2645] [CORE] Allow SparkEnv.stop() to be called multiple times without side effects. Fix for SparkContext stop behavior - Allow sc.stop() to be called multiple times without side effects. Author: Joshi Author: Rekha Joshi Closes #6973 from rekhajoshm/SPARK-2645 and squashes the following commits: 277043e [Joshi] Fix for SparkContext stop behavior 446b0a4 [Joshi] Fix for SparkContext stop behavior 2ce5760 [Joshi] Fix for SparkContext stop behavior c97839a [Joshi] Fix for SparkContext stop behavior 1aff39c [Joshi] Fix for SparkContext stop behavior 12f66b5 [Joshi] Fix for SparkContext stop behavior 72bb484 [Joshi] Fix for SparkContext stop behavior a5a7d7f [Joshi] Fix for SparkContext stop behavior 9193a0c [Joshi] Fix for SparkContext stop behavior 58dba70 [Joshi] SPARK-2645: Fix for SparkContext stop behavior 380c5b0 [Joshi] SPARK-2645: Fix for SparkContext stop behavior b566b66 [Joshi] SPARK-2645: Fix for SparkContext stop behavior 0be142d [Rekha Joshi] Merge pull request #3 from apache/master 106fd8e [Rekha Joshi] Merge pull request #2 from apache/master e3677c9 [Rekha Joshi] Merge pull request #1 from apache/master --- .../scala/org/apache/spark/SparkEnv.scala | 66 ++++++++++--------- .../org/apache/spark/SparkContextSuite.scala | 13 ++++ 2 files changed, 47 insertions(+), 32 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index b0665570e2681..1b133fbdfaf59 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -22,7 +22,6 @@ import java.net.Socket import akka.actor.ActorSystem -import scala.collection.JavaConversions._ import scala.collection.mutable import scala.util.Properties @@ -90,39 +89,42 @@ class SparkEnv ( private var driverTmpDirToDelete: Option[String] = None private[spark] def stop() { - isStopped = true - pythonWorkers.foreach { case(key, worker) => worker.stop() } - Option(httpFileServer).foreach(_.stop()) - mapOutputTracker.stop() - shuffleManager.stop() - broadcastManager.stop() - blockManager.stop() - blockManager.master.stop() - metricsSystem.stop() - outputCommitCoordinator.stop() - rpcEnv.shutdown() - - // Unfortunately Akka's awaitTermination doesn't actually wait for the Netty server to shut - // down, but let's call it anyway in case it gets fixed in a later release - // UPDATE: In Akka 2.1.x, this hangs if there are remote actors, so we can't call it. - // actorSystem.awaitTermination() - - // Note that blockTransferService is stopped by BlockManager since it is started by it. - - // If we only stop sc, but the driver process still run as a services then we need to delete - // the tmp dir, if not, it will create too many tmp dirs. 
- // We only need to delete the tmp dir create by driver, because sparkFilesDir is point to the - // current working dir in executor which we do not need to delete. - driverTmpDirToDelete match { - case Some(path) => { - try { - Utils.deleteRecursively(new File(path)) - } catch { - case e: Exception => - logWarning(s"Exception while deleting Spark temp dir: $path", e) + + if (!isStopped) { + isStopped = true + pythonWorkers.values.foreach(_.stop()) + Option(httpFileServer).foreach(_.stop()) + mapOutputTracker.stop() + shuffleManager.stop() + broadcastManager.stop() + blockManager.stop() + blockManager.master.stop() + metricsSystem.stop() + outputCommitCoordinator.stop() + rpcEnv.shutdown() + + // Unfortunately Akka's awaitTermination doesn't actually wait for the Netty server to shut + // down, but let's call it anyway in case it gets fixed in a later release + // UPDATE: In Akka 2.1.x, this hangs if there are remote actors, so we can't call it. + // actorSystem.awaitTermination() + + // Note that blockTransferService is stopped by BlockManager since it is started by it. + + // If we only stop sc, but the driver process still run as a services then we need to delete + // the tmp dir, if not, it will create too many tmp dirs. + // We only need to delete the tmp dir create by driver, because sparkFilesDir is point to the + // current working dir in executor which we do not need to delete. + driverTmpDirToDelete match { + case Some(path) => { + try { + Utils.deleteRecursively(new File(path)) + } catch { + case e: Exception => + logWarning(s"Exception while deleting Spark temp dir: $path", e) + } } + case None => // We just need to delete tmp dir created by driver, so do nothing on executor } - case None => // We just need to delete tmp dir created by driver, so do nothing on executor } } diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala index 6838b35ab4cc8..5c57940fa5f77 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala @@ -30,6 +30,7 @@ import org.apache.spark.util.Utils import scala.concurrent.Await import scala.concurrent.duration.Duration +import org.scalatest.Matchers._ class SparkContextSuite extends SparkFunSuite with LocalSparkContext { @@ -272,4 +273,16 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext { sc.stop() } } + + test("calling multiple sc.stop() must not throw any exception") { + noException should be thrownBy { + sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local")) + val cnt = sc.parallelize(1 to 4).count() + sc.cancelAllJobs() + sc.stop() + // call stop second time + sc.stop() + } + } + } From 4bb8375fc2c6aa8342df03c3617aa97e7d01de3f Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Tue, 30 Jun 2015 14:01:52 -0700 Subject: [PATCH 101/122] [SPARK-8372] Do not show applications that haven't recorded their app ID yet. Showing these applications may lead to weird behavior in the History Server. For old logs, if the app ID is recorded later, you may end up with a duplicate entry. For new logs, the app might be listed with a ".inprogress" suffix. So ignore those, but still allow old applications that don't record app IDs at all (1.0 and 1.1) to be shown. Author: Marcelo Vanzin Author: Carson Wang Closes #7097 from vanzin/SPARK-8372 and squashes the following commits: a24eab2 [Marcelo Vanzin] Feedback. 
112ae8f [Marcelo Vanzin] Merge branch 'master' into SPARK-8372 7b91b74 [Marcelo Vanzin] Handle logs generated by 1.0 and 1.1. 1eca3fe [Carson Wang] [SPARK-8372] History server shows incorrect information for application not started --- .../deploy/history/FsHistoryProvider.scala | 98 ++++++++++------ .../history/FsHistoryProviderSuite.scala | 109 +++++++++++++----- 2 files changed, 147 insertions(+), 60 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index 5427a88f32ffd..2cc465e55fceb 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -83,12 +83,6 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) // List of application logs to be deleted by event log cleaner. private var attemptsToClean = new mutable.ListBuffer[FsApplicationAttemptInfo] - // Constants used to parse Spark 1.0.0 log directories. - private[history] val LOG_PREFIX = "EVENT_LOG_" - private[history] val SPARK_VERSION_PREFIX = EventLoggingListener.SPARK_VERSION_KEY + "_" - private[history] val COMPRESSION_CODEC_PREFIX = EventLoggingListener.COMPRESSION_CODEC_KEY + "_" - private[history] val APPLICATION_COMPLETE = "APPLICATION_COMPLETE" - /** * Return a runnable that performs the given operation on the event logs. * This operation is expected to be executed periodically. @@ -146,7 +140,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) override def getAppUI(appId: String, attemptId: Option[String]): Option[SparkUI] = { try { applications.get(appId).flatMap { appInfo => - appInfo.attempts.find(_.attemptId == attemptId).map { attempt => + appInfo.attempts.find(_.attemptId == attemptId).flatMap { attempt => val replayBus = new ReplayListenerBus() val ui = { val conf = this.conf.clone() @@ -155,20 +149,20 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) HistoryServer.getAttemptURI(appId, attempt.attemptId), attempt.startTime) // Do not call ui.bind() to avoid creating a new server for each application } - val appListener = new ApplicationEventListener() replayBus.addListener(appListener) val appInfo = replay(fs.getFileStatus(new Path(logDir, attempt.logPath)), replayBus) - - ui.setAppName(s"${appInfo.name} ($appId)") - - val uiAclsEnabled = conf.getBoolean("spark.history.ui.acls.enable", false) - ui.getSecurityManager.setAcls(uiAclsEnabled) - // make sure to set admin acls before view acls so they are properly picked up - ui.getSecurityManager.setAdminAcls(appListener.adminAcls.getOrElse("")) - ui.getSecurityManager.setViewAcls(attempt.sparkUser, - appListener.viewAcls.getOrElse("")) - ui + appInfo.map { info => + ui.setAppName(s"${info.name} ($appId)") + + val uiAclsEnabled = conf.getBoolean("spark.history.ui.acls.enable", false) + ui.getSecurityManager.setAcls(uiAclsEnabled) + // make sure to set admin acls before view acls so they are properly picked up + ui.getSecurityManager.setAdminAcls(appListener.adminAcls.getOrElse("")) + ui.getSecurityManager.setViewAcls(attempt.sparkUser, + appListener.viewAcls.getOrElse("")) + ui + } } } } catch { @@ -282,8 +276,12 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) val newAttempts = logs.flatMap { fileStatus => try { val res = replay(fileStatus, bus) - logInfo(s"Application log ${res.logPath} loaded successfully.") - Some(res) + res match { 
+ case Some(r) => logDebug(s"Application log ${r.logPath} loaded successfully.") + case None => logWarning(s"Failed to load application log ${fileStatus.getPath}. " + + "The application may have not started.") + } + res } catch { case e: Exception => logError( @@ -429,9 +427,11 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) /** * Replays the events in the specified log file and returns information about the associated - * application. + * application. Return `None` if the application ID cannot be located. */ - private def replay(eventLog: FileStatus, bus: ReplayListenerBus): FsApplicationAttemptInfo = { + private def replay( + eventLog: FileStatus, + bus: ReplayListenerBus): Option[FsApplicationAttemptInfo] = { val logPath = eventLog.getPath() logInfo(s"Replaying log path: $logPath") val logInput = @@ -445,16 +445,24 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) val appCompleted = isApplicationCompleted(eventLog) bus.addListener(appListener) bus.replay(logInput, logPath.toString, !appCompleted) - new FsApplicationAttemptInfo( - logPath.getName(), - appListener.appName.getOrElse(NOT_STARTED), - appListener.appId.getOrElse(logPath.getName()), - appListener.appAttemptId, - appListener.startTime.getOrElse(-1L), - appListener.endTime.getOrElse(-1L), - getModificationTime(eventLog).get, - appListener.sparkUser.getOrElse(NOT_STARTED), - appCompleted) + + // Without an app ID, new logs will render incorrectly in the listing page, so do not list or + // try to show their UI. Some old versions of Spark generate logs without an app ID, so let + // logs generated by those versions go through. + if (appListener.appId.isDefined || !sparkVersionHasAppId(eventLog)) { + Some(new FsApplicationAttemptInfo( + logPath.getName(), + appListener.appName.getOrElse(NOT_STARTED), + appListener.appId.getOrElse(logPath.getName()), + appListener.appAttemptId, + appListener.startTime.getOrElse(-1L), + appListener.endTime.getOrElse(-1L), + getModificationTime(eventLog).get, + appListener.sparkUser.getOrElse(NOT_STARTED), + appCompleted)) + } else { + None + } } finally { logInput.close() } @@ -529,10 +537,34 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) } } + /** + * Returns whether the version of Spark that generated logs records app IDs. App IDs were added + * in Spark 1.1. + */ + private def sparkVersionHasAppId(entry: FileStatus): Boolean = { + if (isLegacyLogDirectory(entry)) { + fs.listStatus(entry.getPath()) + .find { status => status.getPath().getName().startsWith(SPARK_VERSION_PREFIX) } + .map { status => + val version = status.getPath().getName().substring(SPARK_VERSION_PREFIX.length()) + version != "1.0" && version != "1.1" + } + .getOrElse(true) + } else { + true + } + } + } -private object FsHistoryProvider { +private[history] object FsHistoryProvider { val DEFAULT_LOG_DIR = "file:/tmp/spark-events" + + // Constants used to parse Spark 1.0.0 log directories. 
+ val LOG_PREFIX = "EVENT_LOG_" + val SPARK_VERSION_PREFIX = EventLoggingListener.SPARK_VERSION_KEY + "_" + val COMPRESSION_CODEC_PREFIX = EventLoggingListener.COMPRESSION_CODEC_KEY + "_" + val APPLICATION_COMPLETE = "APPLICATION_COMPLETE" } private class FsApplicationAttemptInfo( diff --git a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala index 09075eeb539aa..2a62450bcdbad 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala @@ -39,6 +39,8 @@ import org.apache.spark.util.{JsonProtocol, ManualClock, Utils} class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matchers with Logging { + import FsHistoryProvider._ + private var testDir: File = null before { @@ -67,7 +69,8 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc // Write a new-style application log. val newAppComplete = newLogFile("new1", None, inProgress = false) writeFile(newAppComplete, true, None, - SparkListenerApplicationStart("new-app-complete", None, 1L, "test", None), + SparkListenerApplicationStart(newAppComplete.getName(), Some("new-app-complete"), 1L, "test", + None), SparkListenerApplicationEnd(5L) ) @@ -75,35 +78,30 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc val newAppCompressedComplete = newLogFile("new1compressed", None, inProgress = false, Some("lzf")) writeFile(newAppCompressedComplete, true, None, - SparkListenerApplicationStart("new-app-compressed-complete", None, 1L, "test", None), + SparkListenerApplicationStart(newAppCompressedComplete.getName(), Some("new-complete-lzf"), + 1L, "test", None), SparkListenerApplicationEnd(4L)) // Write an unfinished app, new-style. val newAppIncomplete = newLogFile("new2", None, inProgress = true) writeFile(newAppIncomplete, true, None, - SparkListenerApplicationStart("new-app-incomplete", None, 1L, "test", None) + SparkListenerApplicationStart(newAppIncomplete.getName(), Some("new-incomplete"), 1L, "test", + None) ) // Write an old-style application log. - val oldAppComplete = new File(testDir, "old1") - oldAppComplete.mkdir() - createEmptyFile(new File(oldAppComplete, provider.SPARK_VERSION_PREFIX + "1.0")) - writeFile(new File(oldAppComplete, provider.LOG_PREFIX + "1"), false, None, - SparkListenerApplicationStart("old-app-complete", None, 2L, "test", None), + val oldAppComplete = writeOldLog("old1", "1.0", None, true, + SparkListenerApplicationStart("old1", Some("old-app-complete"), 2L, "test", None), SparkListenerApplicationEnd(3L) ) - createEmptyFile(new File(oldAppComplete, provider.APPLICATION_COMPLETE)) // Check for logs so that we force the older unfinished app to be loaded, to make // sure unfinished apps are also sorted correctly. provider.checkForLogs() // Write an unfinished app, old-style. - val oldAppIncomplete = new File(testDir, "old2") - oldAppIncomplete.mkdir() - createEmptyFile(new File(oldAppIncomplete, provider.SPARK_VERSION_PREFIX + "1.0")) - writeFile(new File(oldAppIncomplete, provider.LOG_PREFIX + "1"), false, None, - SparkListenerApplicationStart("old-app-incomplete", None, 2L, "test", None) + val oldAppIncomplete = writeOldLog("old2", "1.0", None, false, + SparkListenerApplicationStart("old2", None, 2L, "test", None) ) // Force a reload of data from the log directory, and check that both logs are loaded. 
@@ -124,16 +122,15 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc List(ApplicationAttemptInfo(None, start, end, lastMod, user, completed))) } - list(0) should be (makeAppInfo(newAppComplete.getName(), "new-app-complete", 1L, 5L, + list(0) should be (makeAppInfo("new-app-complete", newAppComplete.getName(), 1L, 5L, newAppComplete.lastModified(), "test", true)) - list(1) should be (makeAppInfo(newAppCompressedComplete.getName(), - "new-app-compressed-complete", 1L, 4L, newAppCompressedComplete.lastModified(), "test", - true)) - list(2) should be (makeAppInfo(oldAppComplete.getName(), "old-app-complete", 2L, 3L, + list(1) should be (makeAppInfo("new-complete-lzf", newAppCompressedComplete.getName(), + 1L, 4L, newAppCompressedComplete.lastModified(), "test", true)) + list(2) should be (makeAppInfo("old-app-complete", oldAppComplete.getName(), 2L, 3L, oldAppComplete.lastModified(), "test", true)) - list(3) should be (makeAppInfo(oldAppIncomplete.getName(), "old-app-incomplete", 2L, -1L, - oldAppIncomplete.lastModified(), "test", false)) - list(4) should be (makeAppInfo(newAppIncomplete.getName(), "new-app-incomplete", 1L, -1L, + list(3) should be (makeAppInfo(oldAppIncomplete.getName(), oldAppIncomplete.getName(), 2L, + -1L, oldAppIncomplete.lastModified(), "test", false)) + list(4) should be (makeAppInfo("new-incomplete", newAppIncomplete.getName(), 1L, -1L, newAppIncomplete.lastModified(), "test", false)) // Make sure the UI can be rendered. @@ -155,12 +152,12 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc val codec = if (valid) CompressionCodec.createCodec(new SparkConf(), codecName) else null val logDir = new File(testDir, codecName) logDir.mkdir() - createEmptyFile(new File(logDir, provider.SPARK_VERSION_PREFIX + "1.0")) - writeFile(new File(logDir, provider.LOG_PREFIX + "1"), false, Option(codec), + createEmptyFile(new File(logDir, SPARK_VERSION_PREFIX + "1.0")) + writeFile(new File(logDir, LOG_PREFIX + "1"), false, Option(codec), SparkListenerApplicationStart("app2", None, 2L, "test", None), SparkListenerApplicationEnd(3L) ) - createEmptyFile(new File(logDir, provider.COMPRESSION_CODEC_PREFIX + codecName)) + createEmptyFile(new File(logDir, COMPRESSION_CODEC_PREFIX + codecName)) val logPath = new Path(logDir.getAbsolutePath()) try { @@ -180,12 +177,12 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc test("SPARK-3697: ignore directories that cannot be read.") { val logFile1 = newLogFile("new1", None, inProgress = false) writeFile(logFile1, true, None, - SparkListenerApplicationStart("app1-1", None, 1L, "test", None), + SparkListenerApplicationStart("app1-1", Some("app1-1"), 1L, "test", None), SparkListenerApplicationEnd(2L) ) val logFile2 = newLogFile("new2", None, inProgress = false) writeFile(logFile2, true, None, - SparkListenerApplicationStart("app1-2", None, 1L, "test", None), + SparkListenerApplicationStart("app1-2", Some("app1-2"), 1L, "test", None), SparkListenerApplicationEnd(2L) ) logFile2.setReadable(false, false) @@ -218,6 +215,18 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc } } + test("Parse logs that application is not started") { + val provider = new FsHistoryProvider((createTestConf())) + + val logFile1 = newLogFile("app1", None, inProgress = true) + writeFile(logFile1, true, None, + SparkListenerLogStart("1.4") + ) + updateAndCheck(provider) { list => + list.size should be (0) + } + } + test("SPARK-5582: empty log directory") 
{ val provider = new FsHistoryProvider(createTestConf()) @@ -373,6 +382,33 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc } } + test("SPARK-8372: new logs with no app ID are ignored") { + val provider = new FsHistoryProvider(createTestConf()) + + // Write a new log file without an app id, to make sure it's ignored. + val logFile1 = newLogFile("app1", None, inProgress = true) + writeFile(logFile1, true, None, + SparkListenerLogStart("1.4") + ) + + // Write a 1.2 log file with no start event (= no app id), it should be ignored. + writeOldLog("v12Log", "1.2", None, false) + + // Write 1.0 and 1.1 logs, which don't have app ids. + writeOldLog("v11Log", "1.1", None, true, + SparkListenerApplicationStart("v11Log", None, 2L, "test", None), + SparkListenerApplicationEnd(3L)) + writeOldLog("v10Log", "1.0", None, true, + SparkListenerApplicationStart("v10Log", None, 2L, "test", None), + SparkListenerApplicationEnd(4L)) + + updateAndCheck(provider) { list => + list.size should be (2) + list(0).id should be ("v10Log") + list(1).id should be ("v11Log") + } + } + /** * Asks the provider to check for logs and calls a function to perform checks on the updated * app list. Example: @@ -412,4 +448,23 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc new SparkConf().set("spark.history.fs.logDirectory", testDir.getAbsolutePath()) } + private def writeOldLog( + fname: String, + sparkVersion: String, + codec: Option[CompressionCodec], + completed: Boolean, + events: SparkListenerEvent*): File = { + val log = new File(testDir, fname) + log.mkdir() + + val oldEventLog = new File(log, LOG_PREFIX + "1") + createEmptyFile(new File(log, SPARK_VERSION_PREFIX + sparkVersion)) + writeFile(new File(log, LOG_PREFIX + "1"), false, codec, events: _*) + if (completed) { + createEmptyFile(new File(log, APPLICATION_COMPLETE)) + } + + log + } + } From 3ba23ffd377d12383d923d1550ac8e2b916090fc Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Tue, 30 Jun 2015 14:02:50 -0700 Subject: [PATCH 102/122] [SPARK-8736] [ML] GBTRegressor should not threshold prediction Changed GBTRegressor so it does NOT threshold the prediction. Added test which fails with bug but works after fix. CC: feynmanliang mengxr Author: Joseph K. Bradley Closes #7134 from jkbradley/gbrt-fix and squashes the following commits: 613b90e [Joseph K. Bradley] Changed GBTRegressor so it does NOT threshold the prediction --- .../spark/ml/regression/GBTRegressor.scala | 3 +-- .../ml/regression/GBTRegressorSuite.scala | 23 ++++++++++++++++++- 2 files changed, 23 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index 036e3acb07412..47c110d027d67 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -172,8 +172,7 @@ final class GBTRegressionModel( // TODO: When we add a generic Boosting class, handle transform there? 
SPARK-7129 // Classifies by thresholding sum of weighted tree predictions val treePredictions = _trees.map(_.rootNode.predict(features)) - val prediction = blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1) - if (prediction > 0.0) 1.0 else 0.0 + blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1) } override def copy(extra: ParamMap): GBTRegressionModel = { diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala index 98fb3d3f5f22c..9682edcd9ba84 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala @@ -19,12 +19,13 @@ package org.apache.spark.ml.regression import org.apache.spark.SparkFunSuite import org.apache.spark.ml.impl.TreeTests +import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => OldGBT} import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Row} /** @@ -67,6 +68,26 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext { } } + test("GBTRegressor behaves reasonably on toy data") { + val df = sqlContext.createDataFrame(Seq( + LabeledPoint(10, Vectors.dense(1, 2, 3, 4)), + LabeledPoint(-5, Vectors.dense(6, 3, 2, 1)), + LabeledPoint(11, Vectors.dense(2, 2, 3, 4)), + LabeledPoint(-6, Vectors.dense(6, 4, 2, 1)), + LabeledPoint(9, Vectors.dense(1, 2, 6, 4)), + LabeledPoint(-4, Vectors.dense(6, 3, 2, 2)) + )) + val gbt = new GBTRegressor() + .setMaxDepth(2) + .setMaxIter(2) + val model = gbt.fit(df) + val preds = model.transform(df) + val predictions = preds.select("prediction").map(_.getDouble(0)) + // Checks based on SPARK-8736 (to ensure it is not doing classification) + assert(predictions.max() > 2) + assert(predictions.min() < -1) + } + // TODO: Reinstate test once runWithValidation is implemented SPARK-7132 /* test("runWithValidation stops early and performs better on a validation dataset") { From 8c898964f095fcb5bb1c9212e1e484b1eb55c296 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Tue, 30 Jun 2015 14:06:50 -0700 Subject: [PATCH 103/122] [SPARK-8705] [WEBUI] Don't display rects when totalExecutionTime is 0 Because `System.currentTimeMillis()` is not accurate for tasks that only need several milliseconds, sometimes `totalExecutionTime` in `makeTimeline` will be 0. If `totalExecutionTime` is 0, there will the following error in the console. ![screen shot 2015-06-29 at 7 08 55 pm](https://cloud.githubusercontent.com/assets/1000778/8406776/5cd38e04-1e92-11e5-89f2-0c5134fe4b6b.png) This PR fixes it by using an empty svg tag when `totalExecutionTime` is 0. This is a screenshot for a task that its totalExecutionTime is 0 after fixing it. 
![screen shot 2015-06-30 at 12 26 52 am](https://cloud.githubusercontent.com/assets/1000778/8412896/7b33b4be-1ebf-11e5-9100-d6d656af3747.png) Author: zsxwing Closes #7088 from zsxwing/SPARK-8705 and squashes the following commits: 9ee4ef5 [zsxwing] Address comments ef2ecfa [zsxwing] Don't display rects when totalExecutionTime is 0 --- .../org/apache/spark/ui/jobs/StagePage.scala | 52 +++++++++++-------- 1 file changed, 30 insertions(+), 22 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index e96bf49d0dd14..17e7519ddd01c 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -570,6 +570,35 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { val index = taskInfo.index val attempt = taskInfo.attempt + + val svgTag = + if (totalExecutionTime == 0) { + // SPARK-8705: Avoid invalid attribute error in JavaScript if execution time is 0 + """""" + } else { + s""" + | + | + | + | + | + | + |""".stripMargin + } val timelineObject = s""" |{ @@ -595,28 +624,7 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { |
<br>Shuffle Write Time: ${UIUtils.formatDuration(shuffleWriteTime)} |
<br>Result Serialization Time: ${UIUtils.formatDuration(serializationTime)} |<br>
Getting Result Time: ${UIUtils.formatDuration(gettingResultTime)}"> - | - | - | - | - | - | - | - |', + |$svgTag', |'start': new Date($launchTime), |'end': new Date($finishTime) |} From e72526227fdcf93b7a33375ef954746ac08753f5 Mon Sep 17 00:00:00 2001 From: lee19 Date: Tue, 30 Jun 2015 14:08:00 -0700 Subject: [PATCH 104/122] [SPARK-8563] [MLLIB] Fixed a bug so that IndexedRowMatrix.computeSVD().U.numCols = k I'm sorry that I made https://github.com/apache/spark/pull/6949 closed by mistake. I pushed codes again. And, I added a test code. > There is a bug that `U.numCols() = self.nCols` in `IndexedRowMatrix.computeSVD()` It should have been `U.numCols() = k = svd.U.numCols()` > ``` self = U * sigma * V.transpose (m x n) = (m x n) * (k x k) * (k x n) //ASIS --> (m x n) = (m x k) * (k x k) * (k x n) //TOBE ``` Author: lee19 Closes #6953 from lee19/MLlibBugfix and squashes the following commits: c1812a0 [lee19] [SPARK-8563] [MLlib] Used nRows instead of numRows() to reduce a burden. 4b9803b [lee19] [SPARK-8563] [MLlib] Fixed a build error. c2ccd89 [lee19] Added a unit test that validates matrix sizes of svd for [SPARK-8563][MLlib] 8373424 [lee19] [SPARK-8563][MLlib] Fixed a bug so that IndexedRowMatrix.computeSVD().U.numCols = k --- .../mllib/linalg/distributed/IndexedRowMatrix.scala | 2 +- .../linalg/distributed/IndexedRowMatrixSuite.scala | 11 +++++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala index 3be530fa07537..1c33b43ea7a8a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala @@ -146,7 +146,7 @@ class IndexedRowMatrix( val indexedRows = indices.zip(svd.U.rows).map { case (i, v) => IndexedRow(i, v) } - new IndexedRowMatrix(indexedRows, nRows, nCols) + new IndexedRowMatrix(indexedRows, nRows, svd.U.numCols().toInt) } else { null } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala index 4a7b99a976f0a..0ecb7a221a503 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala @@ -135,6 +135,17 @@ class IndexedRowMatrixSuite extends SparkFunSuite with MLlibTestSparkContext { assert(closeToZero(U * brzDiag(s) * V.t - localA)) } + test("validate matrix sizes of svd") { + val k = 2 + val A = new IndexedRowMatrix(indexedRows) + val svd = A.computeSVD(k, computeU = true) + assert(svd.U.numRows() === m) + assert(svd.U.numCols() === k) + assert(svd.s.size === k) + assert(svd.V.numRows === n) + assert(svd.V.numCols === k) + } + test("validate k in svd") { val A = new IndexedRowMatrix(indexedRows) intercept[IllegalArgumentException] { From d2495f7cc7d7caaa50d122d2969ddb693e6ecebd Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Tue, 30 Jun 2015 14:09:29 -0700 Subject: [PATCH 105/122] [SPARK-8739] [WEB UI] [WINDOWS] A illegal character `\r` can be contained in StagePage. This issue was reported by saurfang. Thanks! There is a following code in StagePage.scala. 
``` |width="$serializationTimeProportion%"> |', |'start': new Date($launchTime), |'end': new Date($finishTime) |} |""".stripMargin.replaceAll("\n", " ") ``` The last `replaceAll("\n", "")` doesn't work when we checkout and build source code on Windows and deploy on Linux. It's because when we checkout the source code on Windows, new-line-code is replaced with `"\r\n"` and `replaceAll("\n", "")` replaces only `"\n"`. Author: Kousuke Saruta Closes #7133 from sarutak/SPARK-8739 and squashes the following commits: 17fb044 [Kousuke Saruta] Fixed a new-line-code issue --- core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 17e7519ddd01c..60e3c6343122c 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -628,7 +628,7 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { |'start': new Date($launchTime), |'end': new Date($finishTime) |} - |""".stripMargin.replaceAll("\n", " ") + |""".stripMargin.replaceAll("""[\r\n]+""", " ") timelineObject }.mkString("[", ",", "]") From 58ee2a2e47948a895e557fbcabbeadb31f0a1022 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 30 Jun 2015 16:17:46 -0700 Subject: [PATCH 106/122] [SPARK-8738] [SQL] [PYSPARK] capture SQL AnalysisException in Python API Capture the AnalysisException in SQL, hide the long java stack trace, only show the error message. cc rxin Author: Davies Liu Closes #7135 from davies/ananylis and squashes the following commits: dad7ae7 [Davies Liu] add comment ec0c0e8 [Davies Liu] Update utils.py cdd7edd [Davies Liu] add doc 7b044c2 [Davies Liu] fix python 3 f84d3bd [Davies Liu] capture SQL AnalysisException in Python API --- python/pyspark/rdd.py | 3 +- python/pyspark/sql/context.py | 2 ++ python/pyspark/sql/tests.py | 7 +++++ python/pyspark/sql/utils.py | 54 +++++++++++++++++++++++++++++++++++ 4 files changed, 65 insertions(+), 1 deletion(-) create mode 100644 python/pyspark/sql/utils.py diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index cb20bc8b54027..79dafb0a4ef27 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -126,11 +126,12 @@ def _load_from_socket(port, serializer): # On most of IPv6-ready systems, IPv6 will take precedence. 
for res in socket.getaddrinfo("localhost", port, socket.AF_UNSPEC, socket.SOCK_STREAM): af, socktype, proto, canonname, sa = res + sock = socket.socket(af, socktype, proto) try: - sock = socket.socket(af, socktype, proto) sock.settimeout(3) sock.connect(sa) except socket.error: + sock.close() sock = None continue break diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index 4dda3b430cfbf..4bf232111c496 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -33,6 +33,7 @@ _infer_schema, _has_nulltype, _merge_type, _create_converter, _python_to_sql_converter from pyspark.sql.dataframe import DataFrame from pyspark.sql.readwriter import DataFrameReader +from pyspark.sql.utils import install_exception_handler try: import pandas @@ -96,6 +97,7 @@ def __init__(self, sparkContext, sqlContext=None): self._jvm = self._sc._jvm self._scala_SQLContext = sqlContext _monkey_patch_RDD(self) + install_exception_handler() @property def _ssql_ctx(self): diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 34f397d0ffef0..5af2ce09bc122 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -46,6 +46,7 @@ from pyspark.tests import ReusedPySparkTestCase from pyspark.sql.functions import UserDefinedFunction from pyspark.sql.window import Window +from pyspark.sql.utils import AnalysisException class UTC(datetime.tzinfo): @@ -847,6 +848,12 @@ def test_replace(self): self.assertEqual(row.age, 10) self.assertEqual(row.height, None) + def test_capture_analysis_exception(self): + self.assertRaises(AnalysisException, lambda: self.sqlCtx.sql("select abc")) + self.assertRaises(AnalysisException, lambda: self.df.selectExpr("a + b")) + # RuntimeException should not be captured + self.assertRaises(py4j.protocol.Py4JJavaError, lambda: self.sqlCtx.sql("abc")) + class HiveContextSQLTests(ReusedPySparkTestCase): diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py new file mode 100644 index 0000000000000..8096802e7302f --- /dev/null +++ b/python/pyspark/sql/utils.py @@ -0,0 +1,54 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import py4j + + +class AnalysisException(Exception): + """ + Failed to analyze a SQL query plan. + """ + + +def capture_sql_exception(f): + def deco(*a, **kw): + try: + return f(*a, **kw) + except py4j.protocol.Py4JJavaError as e: + cls, msg = e.java_exception.toString().split(': ', 1) + if cls == 'org.apache.spark.sql.AnalysisException': + raise AnalysisException(msg) + raise + return deco + + +def install_exception_handler(): + """ + Hook an exception handler into Py4j, which could capture some SQL exceptions in Java. + + When calling Java API, it will call `get_return_value` to parse the returned object. 
+ If any exception happened in JVM, the result will be Java exception object, it raise + py4j.protocol.Py4JJavaError. We replace the original `get_return_value` with one that + could capture the Java exception and throw a Python one (with the same error message). + + It's idempotent, could be called multiple times. + """ + original = py4j.protocol.get_return_value + # The original `get_return_value` is not patched, it's idempotent. + patched = capture_sql_exception(original) + # only patch the one used in in py4j.java_gateway (call Java API) + py4j.java_gateway.get_return_value = patched From 8d23587f1d285e93983b4b7d1decea01c2fe2e9e Mon Sep 17 00:00:00 2001 From: sethah Date: Tue, 30 Jun 2015 16:28:25 -0700 Subject: [PATCH 107/122] [SPARK-7739] [MLLIB] Improve ChiSqSelector example code in user guide Author: sethah Closes #7029 from sethah/working_on_SPARK-7739 and squashes the following commits: ef96916 [sethah] Fixing some style issues efea1f8 [sethah] adding clarification to ChiSqSelector example --- docs/mllib-feature-extraction.md | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/docs/mllib-feature-extraction.md b/docs/mllib-feature-extraction.md index 83e937635a55b..a69e41e2a1936 100644 --- a/docs/mllib-feature-extraction.md +++ b/docs/mllib-feature-extraction.md @@ -384,7 +384,7 @@ data2 = labels.zip(normalizer2.transform(features)) [Feature selection](http://en.wikipedia.org/wiki/Feature_selection) allows selecting the most relevant features for use in model construction. Feature selection reduces the size of the vector space and, in turn, the complexity of any subsequent operation with vectors. The number of features to select can be tuned using a held-out validation set. ### ChiSqSelector -[`ChiSqSelector`](api/scala/index.html#org.apache.spark.mllib.feature.ChiSqSelector) stands for Chi-Squared feature selection. It operates on labeled data with categorical features. `ChiSqSelector` orders features based on a Chi-Squared test of independence from the class, and then filters (selects) the top features which are most closely related to the label. +[`ChiSqSelector`](api/scala/index.html#org.apache.spark.mllib.feature.ChiSqSelector) stands for Chi-Squared feature selection. It operates on labeled data with categorical features. `ChiSqSelector` orders features based on a Chi-Squared test of independence from the class, and then filters (selects) the top features which the class label depends on the most. This is akin to yielding the features with the most predictive power. #### Model Fitting @@ -405,7 +405,7 @@ Note that the user can also construct a `ChiSqSelectorModel` by hand by providin #### Example -The following example shows the basic use of ChiSqSelector. +The following example shows the basic use of ChiSqSelector. The data set used has a feature matrix consisting of greyscale values that vary from 0 to 255 for each feature.
@@ -419,10 +419,11 @@ import org.apache.spark.mllib.feature.ChiSqSelector // Load some data in libsvm format val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") // Discretize data in 16 equal bins since ChiSqSelector requires categorical features +// Even though features are doubles, the ChiSqSelector treats each unique value as a category val discretizedData = data.map { lp => - LabeledPoint(lp.label, Vectors.dense(lp.features.toArray.map { x => x / 16 } ) ) + LabeledPoint(lp.label, Vectors.dense(lp.features.toArray.map { x => (x / 16).floor } ) ) } -// Create ChiSqSelector that will select 50 features +// Create ChiSqSelector that will select top 50 of 692 features val selector = new ChiSqSelector(50) // Create ChiSqSelector model (selecting features) val transformer = selector.fit(discretizedData) @@ -451,19 +452,20 @@ JavaRDD points = MLUtils.loadLibSVMFile(sc.sc(), "data/mllib/sample_libsvm_data.txt").toJavaRDD().cache(); // Discretize data in 16 equal bins since ChiSqSelector requires categorical features +// Even though features are doubles, the ChiSqSelector treats each unique value as a category JavaRDD discretizedData = points.map( new Function() { @Override public LabeledPoint call(LabeledPoint lp) { final double[] discretizedFeatures = new double[lp.features().size()]; for (int i = 0; i < lp.features().size(); ++i) { - discretizedFeatures[i] = lp.features().apply(i) / 16; + discretizedFeatures[i] = Math.floor(lp.features().apply(i) / 16); } return new LabeledPoint(lp.label(), Vectors.dense(discretizedFeatures)); } }); -// Create ChiSqSelector that will select 50 features +// Create ChiSqSelector that will select top 50 of 692 features ChiSqSelector selector = new ChiSqSelector(50); // Create ChiSqSelector model (selecting features) final ChiSqSelectorModel transformer = selector.fit(discretizedData.rdd()); From 8133125ca0b83985e0c2aa2a6ad477556867e412 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 30 Jun 2015 16:54:51 -0700 Subject: [PATCH 108/122] [SPARK-8741] [SQL] Remove e and pi from DataFrame functions. Author: Reynold Xin Closes #7137 from rxin/SPARK-8741 and squashes the following commits: 32c7e75 [Reynold Xin] [SPARK-8741][SQL] Remove e and pi from DataFrame functions. --- .../scala/org/apache/spark/sql/functions.scala | 18 ------------------ .../spark/sql/DataFrameFunctionsSuite.scala | 8 -------- 2 files changed, 26 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 6331fe61052ab..5767668dd339b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -989,15 +989,6 @@ object functions { */ def cosh(columnName: String): Column = cosh(Column(columnName)) - /** - * Returns the double value that is closer than any other to e, the base of the natural - * logarithms. - * - * @group math_funcs - * @since 1.5.0 - */ - def e(): Column = EulerNumber() - /** * Computes the exponential of the given value. * @@ -1191,15 +1182,6 @@ object functions { */ def log1p(columnName: String): Column = log1p(Column(columnName)) - /** - * Returns the double value that is closer than any other to pi, the ratio of the circumference - * of a circle to its diameter. - * - * @group math_funcs - * @since 1.5.0 - */ - def pi(): Column = Pi() - /** * Computes the logarithm of the given column in base 2. 
* diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 11a8767ead96c..7ae89bcb1b9cf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -86,14 +86,6 @@ class DataFrameFunctionsSuite extends QueryTest { } test("constant functions") { - checkAnswer( - testData2.select(e()).limit(1), - Row(scala.math.E) - ) - checkAnswer( - testData2.select(pi()).limit(1), - Row(scala.math.Pi) - ) checkAnswer( ctx.sql("SELECT E()"), Row(scala.math.E) From ccdb05222a223187199183fd48e3a3313d536965 Mon Sep 17 00:00:00 2001 From: Tarek Auel Date: Tue, 30 Jun 2015 16:59:44 -0700 Subject: [PATCH 109/122] [SPARK-8727] [SQL] Missing python api; md5, log2 Jira: https://issues.apache.org/jira/browse/SPARK-8727 Author: Tarek Auel Author: Tarek Auel Closes #7114 from tarekauel/missing-python and squashes the following commits: ef4c61b [Tarek Auel] [SPARK-8727] revert dataframe change 4029d4d [Tarek Auel] removed dataframe pi and e unit test 66f0d2b [Tarek Auel] removed pi and e from python api and dataframe api; added _to_java_column(col) for strlen 4d07318 [Tarek Auel] fixed python unit test 45f2bee [Tarek Auel] fixed result of pi and e c39f47b [Tarek Auel] add python api bd50a3a [Tarek Auel] add missing python functions --- python/pyspark/sql/functions.py | 65 ++++++++++++++++++++++++++------- 1 file changed, 52 insertions(+), 13 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 45ecd826bd3bd..4e2be88e9e3b9 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -39,12 +39,15 @@ 'coalesce', 'countDistinct', 'explode', + 'log2', + 'md5', 'monotonicallyIncreasingId', 'rand', 'randn', 'sha1', 'sha2', 'sparkPartitionId', + 'strlen', 'struct', 'udf', 'when'] @@ -320,6 +323,19 @@ def explode(col): return Column(jc) +@ignore_unicode_prefix +@since(1.5) +def md5(col): + """Calculates the MD5 digest and returns the value as a 32 character hex string. + + >>> sqlContext.createDataFrame([('ABC',)], ['a']).select(md5('a').alias('hash')).collect() + [Row(hash=u'902fbdd2b1df0c4f70b4a5d23525e932')] + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.md5(_to_java_column(col)) + return Column(jc) + + @since(1.4) def monotonicallyIncreasingId(): """A column that generates monotonically increasing 64-bit integers. @@ -365,6 +381,19 @@ def randn(seed=None): return Column(jc) +@ignore_unicode_prefix +@since(1.5) +def sha1(col): + """Returns the hex string result of SHA-1. + + >>> sqlContext.createDataFrame([('ABC',)], ['a']).select(sha1('a').alias('hash')).collect() + [Row(hash=u'3c01bdbb26f358bab27f267924aa2c9a03fcfdb8')] + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.sha1(_to_java_column(col)) + return Column(jc) + + @ignore_unicode_prefix @since(1.5) def sha2(col, numBits): @@ -383,19 +412,6 @@ def sha2(col, numBits): return Column(jc) -@ignore_unicode_prefix -@since(1.5) -def sha1(col): - """Returns the hex string result of SHA-1. 
- - >>> sqlContext.createDataFrame([('ABC',)], ['a']).select(sha1('a').alias('hash')).collect() - [Row(hash=u'3c01bdbb26f358bab27f267924aa2c9a03fcfdb8')] - """ - sc = SparkContext._active_spark_context - jc = sc._jvm.functions.sha1(_to_java_column(col)) - return Column(jc) - - @since(1.4) def sparkPartitionId(): """A column for partition ID of the Spark task. @@ -409,6 +425,18 @@ def sparkPartitionId(): return Column(sc._jvm.functions.sparkPartitionId()) +@ignore_unicode_prefix +@since(1.5) +def strlen(col): + """Calculates the length of a string expression. + + >>> sqlContext.createDataFrame([('ABC',)], ['a']).select(strlen('a').alias('length')).collect() + [Row(length=3)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.strlen(_to_java_column(col))) + + @ignore_unicode_prefix @since(1.4) def struct(*cols): @@ -471,6 +499,17 @@ def log(arg1, arg2=None): return Column(jc) +@since(1.5) +def log2(col): + """Returns the base-2 logarithm of the argument. + + >>> sqlContext.createDataFrame([(4,)], ['a']).select(log2('a').alias('log2')).collect() + [Row(log2=2.0)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.log2(_to_java_column(col))) + + @since(1.4) def lag(col, count=1, default=None): """ From 3bee0f1466ddd69f26e95297b5e0d2398b6c6268 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Tue, 30 Jun 2015 17:39:55 -0700 Subject: [PATCH 110/122] [SPARK-6602][Core] Update Master, Worker, Client, AppClient and related classes to use RpcEndpoint This PR updates the remaining Actors in core to RpcEndpoint. Because there is no `ActorSelection` in RpcEnv, I changed the logic of `registerWithMaster` in Worker and AppClient to avoid blocking the message loop. These changes need to be reviewed carefully. Author: zsxwing Closes #5392 from zsxwing/rpc-rewrite-part3 and squashes the following commits: 2de7bed [zsxwing] Merge branch 'master' into rpc-rewrite-part3 f12d943 [zsxwing] Address comments 9137b82 [zsxwing] Fix the code style e734c71 [zsxwing] Merge branch 'master' into rpc-rewrite-part3 2d24fb5 [zsxwing] Fix the code style 5a82374 [zsxwing] Merge branch 'master' into rpc-rewrite-part3 fa47110 [zsxwing] Merge branch 'master' into rpc-rewrite-part3 72304f0 [zsxwing] Update the error strategy for AkkaRpcEnv e56cb16 [zsxwing] Always send failure back to the sender a7b86e6 [zsxwing] Use JFuture for java.util.concurrent.Future aa34b9b [zsxwing] Fix the code style bd541e7 [zsxwing] Merge branch 'master' into rpc-rewrite-part3 25a84d8 [zsxwing] Use ThreadUtils 060ff31 [zsxwing] Merge branch 'master' into rpc-rewrite-part3 dbfc916 [zsxwing] Improve the docs and comments 837927e [zsxwing] Merge branch 'master' into rpc-rewrite-part3 5c27f97 [zsxwing] Merge branch 'master' into rpc-rewrite-part3 fadbb9e [zsxwing] Fix the code style 6637e3c [zsxwing] Merge remote-tracking branch 'origin/master' into rpc-rewrite-part3 7fdee0e [zsxwing] Fix the return type to ExecutorService and ScheduledExecutorService e8ad0a5 [zsxwing] Fix the code style 6b2a104 [zsxwing] Log error and use SparkExitCode.UNCAUGHT_EXCEPTION exit code fbf3194 [zsxwing] Add Utils.newDaemonSingleThreadExecutor and newDaemonSingleThreadScheduledExecutor b776817 [zsxwing] Update Master, Worker, Client, AppClient and related classes to use RpcEndpoint --- .../org/apache/spark/deploy/Client.scala | 156 ++++--- .../apache/spark/deploy/DeployMessage.scala | 22 +- .../spark/deploy/LocalSparkCluster.scala | 26 +- .../spark/deploy/client/AppClient.scala | 199 +++++----
.../spark/deploy/client/TestClient.scala | 10 +- .../spark/deploy/master/ApplicationInfo.scala | 5 +- .../apache/spark/deploy/master/Master.scala | 392 +++++++++--------- .../spark/deploy/master/MasterMessages.scala | 2 +- .../spark/deploy/master/WorkerInfo.scala | 6 +- .../master/ZooKeeperLeaderElectionAgent.scala | 3 - .../deploy/master/ui/ApplicationPage.scala | 9 +- .../spark/deploy/master/ui/MasterPage.scala | 14 +- .../spark/deploy/master/ui/MasterWebUI.scala | 4 +- .../deploy/rest/StandaloneRestServer.scala | 35 +- .../spark/deploy/worker/DriverRunner.scala | 6 +- .../spark/deploy/worker/ExecutorRunner.scala | 8 +- .../apache/spark/deploy/worker/Worker.scala | 318 +++++++++----- .../spark/deploy/worker/WorkerWatcher.scala | 1 - .../spark/deploy/worker/ui/WorkerPage.scala | 11 +- .../scala/org/apache/spark/rpc/RpcEnv.scala | 2 + .../apache/spark/rpc/akka/AkkaRpcEnv.scala | 8 +- .../cluster/SparkDeploySchedulerBackend.scala | 2 +- .../spark/deploy/master/MasterSuite.scala | 56 +-- .../rest/StandaloneRestSubmitSuite.scala | 54 +-- .../deploy/worker/WorkerWatcherSuite.scala | 15 +- .../apache/spark/rpc/RpcAddressSuite.scala | 55 +++ .../spark/rpc/akka/AkkaRpcEnvSuite.scala | 20 +- 27 files changed, 806 insertions(+), 633 deletions(-) create mode 100644 core/src/test/scala/org/apache/spark/rpc/RpcAddressSuite.scala diff --git a/core/src/main/scala/org/apache/spark/deploy/Client.scala b/core/src/main/scala/org/apache/spark/deploy/Client.scala index 848b62f9de71b..71f7e2129116f 100644 --- a/core/src/main/scala/org/apache/spark/deploy/Client.scala +++ b/core/src/main/scala/org/apache/spark/deploy/Client.scala @@ -18,17 +18,17 @@ package org.apache.spark.deploy import scala.collection.mutable.HashSet -import scala.concurrent._ +import scala.concurrent.ExecutionContext +import scala.reflect.ClassTag +import scala.util.{Failure, Success} -import akka.actor._ -import akka.pattern.ask -import akka.remote.{AssociationErrorEvent, DisassociatedEvent, RemotingLifecycleEvent} import org.apache.log4j.{Level, Logger} +import org.apache.spark.rpc.{RpcEndpointRef, RpcAddress, RpcEnv, ThreadSafeRpcEndpoint} import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.deploy.DeployMessages._ import org.apache.spark.deploy.master.{DriverState, Master} -import org.apache.spark.util.{ActorLogReceive, AkkaUtils, RpcUtils, Utils} +import org.apache.spark.util.{ThreadUtils, SparkExitCode, Utils} /** * Proxy that relays messages to the driver. @@ -36,20 +36,30 @@ import org.apache.spark.util.{ActorLogReceive, AkkaUtils, RpcUtils, Utils} * We currently don't support retry if submission fails. In HA mode, client will submit request to * all masters and see which one could handle it. */ -private class ClientActor(driverArgs: ClientArguments, conf: SparkConf) - extends Actor with ActorLogReceive with Logging { - - private val masterActors = driverArgs.masters.map { m => - context.actorSelection(Master.toAkkaUrl(m, AkkaUtils.protocol(context.system))) - } - private val lostMasters = new HashSet[Address] - private var activeMasterActor: ActorSelection = null - - val timeout = RpcUtils.askTimeout(conf) - - override def preStart(): Unit = { - context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) - +private class ClientEndpoint( + override val rpcEnv: RpcEnv, + driverArgs: ClientArguments, + masterEndpoints: Seq[RpcEndpointRef], + conf: SparkConf) + extends ThreadSafeRpcEndpoint with Logging { + + // A scheduled executor used to send messages at the specified time. 
+ private val forwardMessageThread = + ThreadUtils.newDaemonSingleThreadScheduledExecutor("client-forward-message") + // Used to provide the implicit parameter of `Future` methods. + private val forwardMessageExecutionContext = + ExecutionContext.fromExecutor(forwardMessageThread, + t => t match { + case ie: InterruptedException => // Exit normally + case e: Throwable => + logError(e.getMessage, e) + System.exit(SparkExitCode.UNCAUGHT_EXCEPTION) + }) + + private val lostMasters = new HashSet[RpcAddress] + private var activeMasterEndpoint: RpcEndpointRef = null + + override def onStart(): Unit = { driverArgs.cmd match { case "launch" => // TODO: We could add an env variable here and intercept it in `sc.addJar` that would @@ -82,29 +92,37 @@ private class ClientActor(driverArgs: ClientArguments, conf: SparkConf) driverArgs.cores, driverArgs.supervise, command) - - // This assumes only one Master is active at a time - for (masterActor <- masterActors) { - masterActor ! RequestSubmitDriver(driverDescription) - } + ayncSendToMasterAndForwardReply[SubmitDriverResponse]( + RequestSubmitDriver(driverDescription)) case "kill" => val driverId = driverArgs.driverId - // This assumes only one Master is active at a time - for (masterActor <- masterActors) { - masterActor ! RequestKillDriver(driverId) - } + ayncSendToMasterAndForwardReply[KillDriverResponse](RequestKillDriver(driverId)) + } + } + + /** + * Send the message to master and forward the reply to self asynchronously. + */ + private def ayncSendToMasterAndForwardReply[T: ClassTag](message: Any): Unit = { + for (masterEndpoint <- masterEndpoints) { + masterEndpoint.ask[T](message).onComplete { + case Success(v) => self.send(v) + case Failure(e) => + logWarning(s"Error sending messages to master $masterEndpoint", e) + }(forwardMessageExecutionContext) } } /* Find out driver status then exit the JVM */ def pollAndReportStatus(driverId: String) { + // Since ClientEndpoint is the only RpcEndpoint in the process, blocking the event loop thread + // is fine. println("... waiting before polling master for driver state") Thread.sleep(5000) println("... polling master for driver state") - val statusFuture = (activeMasterActor ? 
RequestDriverStatus(driverId))(timeout) - .mapTo[DriverStatusResponse] - val statusResponse = Await.result(statusFuture, timeout) + val statusResponse = + activeMasterEndpoint.askWithRetry[DriverStatusResponse](RequestDriverStatus(driverId)) statusResponse.found match { case false => println(s"ERROR: Cluster master did not recognize $driverId") @@ -127,50 +145,62 @@ private class ClientActor(driverArgs: ClientArguments, conf: SparkConf) } } - override def receiveWithLogging: PartialFunction[Any, Unit] = { + override def receive: PartialFunction[Any, Unit] = { - case SubmitDriverResponse(success, driverId, message) => + case SubmitDriverResponse(master, success, driverId, message) => println(message) if (success) { - activeMasterActor = context.actorSelection(sender.path) + activeMasterEndpoint = master pollAndReportStatus(driverId.get) } else if (!Utils.responseFromBackup(message)) { System.exit(-1) } - case KillDriverResponse(driverId, success, message) => + case KillDriverResponse(master, driverId, success, message) => println(message) if (success) { - activeMasterActor = context.actorSelection(sender.path) + activeMasterEndpoint = master pollAndReportStatus(driverId) } else if (!Utils.responseFromBackup(message)) { System.exit(-1) } + } - case DisassociatedEvent(_, remoteAddress, _) => - if (!lostMasters.contains(remoteAddress)) { - println(s"Error connecting to master $remoteAddress.") - lostMasters += remoteAddress - // Note that this heuristic does not account for the fact that a Master can recover within - // the lifetime of this client. Thus, once a Master is lost it is lost to us forever. This - // is not currently a concern, however, because this client does not retry submissions. - if (lostMasters.size >= masterActors.size) { - println("No master is available, exiting.") - System.exit(-1) - } + override def onDisconnected(remoteAddress: RpcAddress): Unit = { + if (!lostMasters.contains(remoteAddress)) { + println(s"Error connecting to master $remoteAddress.") + lostMasters += remoteAddress + // Note that this heuristic does not account for the fact that a Master can recover within + // the lifetime of this client. Thus, once a Master is lost it is lost to us forever. This + // is not currently a concern, however, because this client does not retry submissions. 
+ if (lostMasters.size >= masterEndpoints.size) { + println("No master is available, exiting.") + System.exit(-1) } + } + } - case AssociationErrorEvent(cause, _, remoteAddress, _, _) => - if (!lostMasters.contains(remoteAddress)) { - println(s"Error connecting to master ($remoteAddress).") - println(s"Cause was: $cause") - lostMasters += remoteAddress - if (lostMasters.size >= masterActors.size) { - println("No master is available, exiting.") - System.exit(-1) - } + override def onNetworkError(cause: Throwable, remoteAddress: RpcAddress): Unit = { + if (!lostMasters.contains(remoteAddress)) { + println(s"Error connecting to master ($remoteAddress).") + println(s"Cause was: $cause") + lostMasters += remoteAddress + if (lostMasters.size >= masterEndpoints.size) { + println("No master is available, exiting.") + System.exit(-1) } + } + } + + override def onError(cause: Throwable): Unit = { + println(s"Error processing messages, exiting.") + cause.printStackTrace() + System.exit(-1) + } + + override def onStop(): Unit = { + forwardMessageThread.shutdownNow() } } @@ -194,15 +224,13 @@ object Client { conf.set("akka.loglevel", driverArgs.logLevel.toString.replace("WARN", "WARNING")) Logger.getRootLogger.setLevel(driverArgs.logLevel) - val (actorSystem, _) = AkkaUtils.createActorSystem( - "driverClient", Utils.localHostName(), 0, conf, new SecurityManager(conf)) + val rpcEnv = + RpcEnv.create("driverClient", Utils.localHostName(), 0, conf, new SecurityManager(conf)) - // Verify driverArgs.master is a valid url so that we can use it in ClientActor safely - for (m <- driverArgs.masters) { - Master.toAkkaUrl(m, AkkaUtils.protocol(actorSystem)) - } - actorSystem.actorOf(Props(classOf[ClientActor], driverArgs, conf)) + val masterEndpoints = driverArgs.masters.map(RpcAddress.fromSparkURL). + map(rpcEnv.setupEndpointRef(Master.SYSTEM_NAME, _, Master.ENDPOINT_NAME)) + rpcEnv.setupEndpoint("client", new ClientEndpoint(rpcEnv, driverArgs, masterEndpoints, conf)) - actorSystem.awaitTermination() + rpcEnv.awaitTermination() } } diff --git a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala index 9db6fd1ac4dbe..12727de9b4cf3 100644 --- a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala @@ -24,11 +24,12 @@ import org.apache.spark.deploy.master.{ApplicationInfo, DriverInfo, WorkerInfo} import org.apache.spark.deploy.master.DriverState.DriverState import org.apache.spark.deploy.master.RecoveryState.MasterState import org.apache.spark.deploy.worker.{DriverRunner, ExecutorRunner} +import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.util.Utils private[deploy] sealed trait DeployMessage extends Serializable -/** Contains messages sent between Scheduler actor nodes. */ +/** Contains messages sent between Scheduler endpoint nodes. 
*/ private[deploy] object DeployMessages { // Worker to Master @@ -37,6 +38,7 @@ private[deploy] object DeployMessages { id: String, host: String, port: Int, + worker: RpcEndpointRef, cores: Int, memory: Int, webUiPort: Int, @@ -63,11 +65,11 @@ private[deploy] object DeployMessages { case class WorkerSchedulerStateResponse(id: String, executors: List[ExecutorDescription], driverIds: Seq[String]) - case class Heartbeat(workerId: String) extends DeployMessage + case class Heartbeat(workerId: String, worker: RpcEndpointRef) extends DeployMessage // Master to Worker - case class RegisteredWorker(masterUrl: String, masterWebUiUrl: String) extends DeployMessage + case class RegisteredWorker(master: RpcEndpointRef, masterWebUiUrl: String) extends DeployMessage case class RegisterWorkerFailed(message: String) extends DeployMessage @@ -92,13 +94,13 @@ private[deploy] object DeployMessages { // Worker internal - case object WorkDirCleanup // Sent to Worker actor periodically for cleaning up app folders + case object WorkDirCleanup // Sent to Worker endpoint periodically for cleaning up app folders case object ReregisterWithMaster // used when a worker attempts to reconnect to a master // AppClient to Master - case class RegisterApplication(appDescription: ApplicationDescription) + case class RegisterApplication(appDescription: ApplicationDescription, driver: RpcEndpointRef) extends DeployMessage case class UnregisterApplication(appId: String) @@ -107,7 +109,7 @@ private[deploy] object DeployMessages { // Master to AppClient - case class RegisteredApplication(appId: String, masterUrl: String) extends DeployMessage + case class RegisteredApplication(appId: String, master: RpcEndpointRef) extends DeployMessage // TODO(matei): replace hostPort with host case class ExecutorAdded(id: Int, workerId: String, hostPort: String, cores: Int, memory: Int) { @@ -123,12 +125,14 @@ private[deploy] object DeployMessages { case class RequestSubmitDriver(driverDescription: DriverDescription) extends DeployMessage - case class SubmitDriverResponse(success: Boolean, driverId: Option[String], message: String) + case class SubmitDriverResponse( + master: RpcEndpointRef, success: Boolean, driverId: Option[String], message: String) extends DeployMessage case class RequestKillDriver(driverId: String) extends DeployMessage - case class KillDriverResponse(driverId: String, success: Boolean, message: String) + case class KillDriverResponse( + master: RpcEndpointRef, driverId: String, success: Boolean, message: String) extends DeployMessage case class RequestDriverStatus(driverId: String) extends DeployMessage @@ -142,7 +146,7 @@ private[deploy] object DeployMessages { // Master to Worker & AppClient - case class MasterChanged(masterUrl: String, masterWebUiUrl: String) + case class MasterChanged(master: RpcEndpointRef, masterWebUiUrl: String) // MasterWebUI To Master diff --git a/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala b/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala index 0550f00a172ab..53356addf6edb 100644 --- a/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala +++ b/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala @@ -19,8 +19,7 @@ package org.apache.spark.deploy import scala.collection.mutable.ArrayBuffer -import akka.actor.ActorSystem - +import org.apache.spark.rpc.RpcEnv import org.apache.spark.{Logging, SparkConf} import org.apache.spark.deploy.worker.Worker import org.apache.spark.deploy.master.Master @@ -41,8 +40,8 @@ class 
LocalSparkCluster( extends Logging { private val localHostname = Utils.localHostName() - private val masterActorSystems = ArrayBuffer[ActorSystem]() - private val workerActorSystems = ArrayBuffer[ActorSystem]() + private val masterRpcEnvs = ArrayBuffer[RpcEnv]() + private val workerRpcEnvs = ArrayBuffer[RpcEnv]() // exposed for testing var masterWebUIPort = -1 @@ -55,18 +54,17 @@ class LocalSparkCluster( .set("spark.shuffle.service.enabled", "false") /* Start the Master */ - val (masterSystem, masterPort, webUiPort, _) = - Master.startSystemAndActor(localHostname, 0, 0, _conf) + val (rpcEnv, webUiPort, _) = Master.startRpcEnvAndEndpoint(localHostname, 0, 0, _conf) masterWebUIPort = webUiPort - masterActorSystems += masterSystem - val masterUrl = "spark://" + Utils.localHostNameForURI() + ":" + masterPort + masterRpcEnvs += rpcEnv + val masterUrl = "spark://" + Utils.localHostNameForURI() + ":" + rpcEnv.address.port val masters = Array(masterUrl) /* Start the Workers */ for (workerNum <- 1 to numWorkers) { - val (workerSystem, _) = Worker.startSystemAndActor(localHostname, 0, 0, coresPerWorker, + val workerEnv = Worker.startRpcEnvAndEndpoint(localHostname, 0, 0, coresPerWorker, memoryPerWorker, masters, null, Some(workerNum), _conf) - workerActorSystems += workerSystem + workerRpcEnvs += workerEnv } masters @@ -77,11 +75,11 @@ class LocalSparkCluster( // Stop the workers before the master so they don't get upset that it disconnected // TODO: In Akka 2.1.x, ActorSystem.awaitTermination hangs when you have remote actors! // This is unfortunate, but for now we just comment it out. - workerActorSystems.foreach(_.shutdown()) + workerRpcEnvs.foreach(_.shutdown()) // workerActorSystems.foreach(_.awaitTermination()) - masterActorSystems.foreach(_.shutdown()) + masterRpcEnvs.foreach(_.shutdown()) // masterActorSystems.foreach(_.awaitTermination()) - masterActorSystems.clear() - workerActorSystems.clear() + masterRpcEnvs.clear() + workerRpcEnvs.clear() } } diff --git a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala index 43c8a934c311a..79b251e7e62fe 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala @@ -17,20 +17,17 @@ package org.apache.spark.deploy.client -import java.util.concurrent.TimeoutException +import java.util.concurrent._ +import java.util.concurrent.{Future => JFuture, ScheduledFuture => JScheduledFuture} -import scala.concurrent.Await -import scala.concurrent.duration._ - -import akka.actor._ -import akka.pattern.ask -import akka.remote.{AssociationErrorEvent, DisassociatedEvent, RemotingLifecycleEvent} +import scala.util.control.NonFatal import org.apache.spark.{Logging, SparkConf} import org.apache.spark.deploy.{ApplicationDescription, ExecutorState} import org.apache.spark.deploy.DeployMessages._ import org.apache.spark.deploy.master.Master -import org.apache.spark.util.{ActorLogReceive, RpcUtils, Utils, AkkaUtils} +import org.apache.spark.rpc._ +import org.apache.spark.util.{ThreadUtils, Utils} /** * Interface allowing applications to speak with a Spark deploy cluster. Takes a master URL, @@ -40,98 +37,143 @@ import org.apache.spark.util.{ActorLogReceive, RpcUtils, Utils, AkkaUtils} * @param masterUrls Each url should look like spark://host:port. 
*/ private[spark] class AppClient( - actorSystem: ActorSystem, + rpcEnv: RpcEnv, masterUrls: Array[String], appDescription: ApplicationDescription, listener: AppClientListener, conf: SparkConf) extends Logging { - private val masterAkkaUrls = masterUrls.map(Master.toAkkaUrl(_, AkkaUtils.protocol(actorSystem))) + private val masterRpcAddresses = masterUrls.map(RpcAddress.fromSparkURL(_)) - private val REGISTRATION_TIMEOUT = 20.seconds + private val REGISTRATION_TIMEOUT_SECONDS = 20 private val REGISTRATION_RETRIES = 3 - private var masterAddress: Address = null - private var actor: ActorRef = null + private var endpoint: RpcEndpointRef = null private var appId: String = null - private var registered = false - private var activeMasterUrl: String = null + @volatile private var registered = false + + private class ClientEndpoint(override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint + with Logging { + + private var master: Option[RpcEndpointRef] = None + // To avoid calling listener.disconnected() multiple times + private var alreadyDisconnected = false + @volatile private var alreadyDead = false // To avoid calling listener.dead() multiple times + @volatile private var registerMasterFutures: Array[JFuture[_]] = null + @volatile private var registrationRetryTimer: JScheduledFuture[_] = null + + // A thread pool for registering with masters. Because registering with a master is a blocking + // action, this thread pool must be able to create "masterRpcAddresses.size" threads at the same + // time so that we can register with all masters. + private val registerMasterThreadPool = new ThreadPoolExecutor( + 0, + masterRpcAddresses.size, // Make sure we can register with all masters at the same time + 60L, TimeUnit.SECONDS, + new SynchronousQueue[Runnable](), + ThreadUtils.namedThreadFactory("appclient-register-master-threadpool")) - private class ClientActor extends Actor with ActorLogReceive with Logging { - var master: ActorSelection = null - var alreadyDisconnected = false // To avoid calling listener.disconnected() multiple times - var alreadyDead = false // To avoid calling listener.dead() multiple times - var registrationRetryTimer: Option[Cancellable] = None + // A scheduled executor for scheduling the registration actions + private val registrationRetryThread = + ThreadUtils.newDaemonSingleThreadScheduledExecutor("appclient-registration-retry-thread") - override def preStart() { - context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) + override def onStart(): Unit = { try { - registerWithMaster() + registerWithMaster(1) } catch { case e: Exception => logWarning("Failed to connect to master", e) markDisconnected() - context.stop(self) + stop() } } - def tryRegisterAllMasters() { - for (masterAkkaUrl <- masterAkkaUrls) { - logInfo("Connecting to master " + masterAkkaUrl + "...") - val actor = context.actorSelection(masterAkkaUrl) - actor ! RegisterApplication(appDescription) + /** + * Register with all masters asynchronously and returns an array `Future`s for cancellation. 
+ */ + private def tryRegisterAllMasters(): Array[JFuture[_]] = { + for (masterAddress <- masterRpcAddresses) yield { + registerMasterThreadPool.submit(new Runnable { + override def run(): Unit = try { + if (registered) { + return + } + logInfo("Connecting to master " + masterAddress.toSparkURL + "...") + val masterRef = + rpcEnv.setupEndpointRef(Master.SYSTEM_NAME, masterAddress, Master.ENDPOINT_NAME) + masterRef.send(RegisterApplication(appDescription, self)) + } catch { + case ie: InterruptedException => // Cancelled + case NonFatal(e) => logWarning(s"Failed to connect to master $masterAddress", e) + } + }) } } - def registerWithMaster() { - tryRegisterAllMasters() - import context.dispatcher - var retries = 0 - registrationRetryTimer = Some { - context.system.scheduler.schedule(REGISTRATION_TIMEOUT, REGISTRATION_TIMEOUT) { + /** + * Register with all masters asynchronously. It will call `registerWithMaster` every + * REGISTRATION_TIMEOUT_SECONDS seconds until exceeding REGISTRATION_RETRIES times. + * Once we connect to a master successfully, all scheduling work and Futures will be cancelled. + * + * nthRetry means this is the nth attempt to register with master. + */ + private def registerWithMaster(nthRetry: Int) { + registerMasterFutures = tryRegisterAllMasters() + registrationRetryTimer = registrationRetryThread.scheduleAtFixedRate(new Runnable { + override def run(): Unit = { Utils.tryOrExit { - retries += 1 if (registered) { - registrationRetryTimer.foreach(_.cancel()) - } else if (retries >= REGISTRATION_RETRIES) { + registerMasterFutures.foreach(_.cancel(true)) + registerMasterThreadPool.shutdownNow() + } else if (nthRetry >= REGISTRATION_RETRIES) { markDead("All masters are unresponsive! Giving up.") } else { - tryRegisterAllMasters() + registerMasterFutures.foreach(_.cancel(true)) + registerWithMaster(nthRetry + 1) } } } - } + }, REGISTRATION_TIMEOUT_SECONDS, REGISTRATION_TIMEOUT_SECONDS, TimeUnit.SECONDS) } - def changeMaster(url: String) { - // activeMasterUrl is a valid Spark url since we receive it from master. - activeMasterUrl = url - master = context.actorSelection( - Master.toAkkaUrl(activeMasterUrl, AkkaUtils.protocol(actorSystem))) - masterAddress = Master.toAkkaAddress(activeMasterUrl, AkkaUtils.protocol(actorSystem)) + /** + * Send a message to the current master. If we have not yet registered successfully with any + * master, the message will be dropped. + */ + private def sendToMaster(message: Any): Unit = { + master match { + case Some(masterRef) => masterRef.send(message) + case None => logWarning(s"Drop $message because has not yet connected to master") + } } - private def isPossibleMaster(remoteUrl: Address) = { - masterAkkaUrls.map(AddressFromURIString(_).hostPort).contains(remoteUrl.hostPort) + private def isPossibleMaster(remoteAddress: RpcAddress): Boolean = { + masterRpcAddresses.contains(remoteAddress) } - override def receiveWithLogging: PartialFunction[Any, Unit] = { - case RegisteredApplication(appId_, masterUrl) => + override def receive: PartialFunction[Any, Unit] = { + case RegisteredApplication(appId_, masterRef) => + // FIXME How to handle the following cases? + // 1. A master receives multiple registrations and sends back multiple + // RegisteredApplications due to an unstable network. + // 2. Receive multiple RegisteredApplication from different masters because the master is + // changing. 
appId = appId_ registered = true - changeMaster(masterUrl) + master = Some(masterRef) listener.connected(appId) case ApplicationRemoved(message) => markDead("Master removed our application: %s".format(message)) - context.stop(self) + stop() case ExecutorAdded(id: Int, workerId: String, hostPort: String, cores: Int, memory: Int) => val fullId = appId + "/" + id logInfo("Executor added: %s on %s (%s) with %d cores".format(fullId, workerId, hostPort, cores)) - master ! ExecutorStateChanged(appId, id, ExecutorState.RUNNING, None, None) + // FIXME if changing master and `ExecutorAdded` happen at the same time (the order is not + // guaranteed), `ExecutorStateChanged` may be sent to a dead master. + sendToMaster(ExecutorStateChanged(appId, id, ExecutorState.RUNNING, None, None)) listener.executorAdded(fullId, workerId, hostPort, cores, memory) case ExecutorUpdated(id, state, message, exitStatus) => @@ -142,24 +184,32 @@ private[spark] class AppClient( listener.executorRemoved(fullId, message.getOrElse(""), exitStatus) } - case MasterChanged(masterUrl, masterWebUiUrl) => - logInfo("Master has changed, new master is at " + masterUrl) - changeMaster(masterUrl) + case MasterChanged(masterRef, masterWebUiUrl) => + logInfo("Master has changed, new master is at " + masterRef.address.toSparkURL) + master = Some(masterRef) alreadyDisconnected = false - sender ! MasterChangeAcknowledged(appId) + masterRef.send(MasterChangeAcknowledged(appId)) + } - case DisassociatedEvent(_, address, _) if address == masterAddress => + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case StopAppClient => + markDead("Application has been stopped.") + sendToMaster(UnregisterApplication(appId)) + context.reply(true) + stop() + } + + override def onDisconnected(address: RpcAddress): Unit = { + if (master.exists(_.address == address)) { logWarning(s"Connection to $address failed; waiting for master to reconnect...") markDisconnected() + } + } - case AssociationErrorEvent(cause, _, address, _, _) if isPossibleMaster(address) => + override def onNetworkError(cause: Throwable, address: RpcAddress): Unit = { + if (isPossibleMaster(address)) { logWarning(s"Could not connect to $address: $cause") - - case StopAppClient => - markDead("Application has been stopped.") - master ! UnregisterApplication(appId) - sender ! true - context.stop(self) + } } /** @@ -179,28 +229,31 @@ private[spark] class AppClient( } } - override def postStop() { - registrationRetryTimer.foreach(_.cancel()) + override def onStop(): Unit = { + if (registrationRetryTimer != null) { + registrationRetryTimer.cancel(true) + } + registrationRetryThread.shutdownNow() + registerMasterFutures.foreach(_.cancel(true)) + registerMasterThreadPool.shutdownNow() } } def start() { // Just launch an actor; it will call back into the listener. 
- actor = actorSystem.actorOf(Props(new ClientActor)) + endpoint = rpcEnv.setupEndpoint("AppClient", new ClientEndpoint(rpcEnv)) } def stop() { - if (actor != null) { + if (endpoint != null) { try { - val timeout = RpcUtils.askTimeout(conf) - val future = actor.ask(StopAppClient)(timeout) - Await.result(future, timeout) + endpoint.askWithRetry[Boolean](StopAppClient) } catch { case e: TimeoutException => logInfo("Stop request to Master timed out; it may already be shut down.") } - actor = null + endpoint = null } } } diff --git a/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala index 40835b9550586..1c79089303e3d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala @@ -17,9 +17,10 @@ package org.apache.spark.deploy.client +import org.apache.spark.rpc.RpcEnv import org.apache.spark.{SecurityManager, SparkConf, Logging} import org.apache.spark.deploy.{ApplicationDescription, Command} -import org.apache.spark.util.{AkkaUtils, Utils} +import org.apache.spark.util.Utils private[spark] object TestClient { @@ -46,13 +47,12 @@ private[spark] object TestClient { def main(args: Array[String]) { val url = args(0) val conf = new SparkConf - val (actorSystem, _) = AkkaUtils.createActorSystem("spark", Utils.localHostName(), 0, - conf = conf, securityManager = new SecurityManager(conf)) + val rpcEnv = RpcEnv.create("spark", Utils.localHostName(), 0, conf, new SecurityManager(conf)) val desc = new ApplicationDescription("TestClient", Some(1), 512, Command("spark.deploy.client.TestExecutor", Seq(), Map(), Seq(), Seq(), Seq()), "ignored") val listener = new TestListener - val client = new AppClient(actorSystem, Array(url), desc, listener, new SparkConf) + val client = new AppClient(rpcEnv, Array(url), desc, listener, new SparkConf) client.start() - actorSystem.awaitTermination() + rpcEnv.awaitTermination() } } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala index 1620e95bea218..aa54ed9360f36 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala @@ -22,10 +22,9 @@ import java.util.Date import scala.collection.mutable import scala.collection.mutable.ArrayBuffer -import akka.actor.ActorRef - import org.apache.spark.annotation.DeveloperApi import org.apache.spark.deploy.ApplicationDescription +import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.util.Utils private[spark] class ApplicationInfo( @@ -33,7 +32,7 @@ private[spark] class ApplicationInfo( val id: String, val desc: ApplicationDescription, val submitDate: Date, - val driver: ActorRef, + val driver: RpcEndpointRef, defaultCores: Int) extends Serializable { diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index fccceb3ea528b..3e7c16722805e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -21,20 +21,18 @@ import java.io.FileNotFoundException import java.net.URLEncoder import java.text.SimpleDateFormat import java.util.Date +import java.util.concurrent.{ScheduledFuture, TimeUnit} import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} 
-import scala.concurrent.Await -import scala.concurrent.duration._ import scala.language.postfixOps import scala.util.Random -import akka.actor._ -import akka.pattern.ask -import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent} import akka.serialization.Serialization import akka.serialization.SerializationExtension import org.apache.hadoop.fs.Path +import org.apache.spark.rpc.akka.AkkaRpcEnv +import org.apache.spark.rpc._ import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException} import org.apache.spark.deploy.{ApplicationDescription, DriverDescription, ExecutorState, SparkHadoopUtil} @@ -47,23 +45,27 @@ import org.apache.spark.deploy.rest.StandaloneRestServer import org.apache.spark.metrics.MetricsSystem import org.apache.spark.scheduler.{EventLoggingListener, ReplayListenerBus} import org.apache.spark.ui.SparkUI -import org.apache.spark.util.{ActorLogReceive, AkkaUtils, RpcUtils, SignalLogger, Utils} +import org.apache.spark.util.{ThreadUtils, SignalLogger, Utils} private[master] class Master( - host: String, - port: Int, + override val rpcEnv: RpcEnv, + address: RpcAddress, webUiPort: Int, val securityMgr: SecurityManager, val conf: SparkConf) - extends Actor with ActorLogReceive with Logging with LeaderElectable { + extends ThreadSafeRpcEndpoint with Logging with LeaderElectable { - import context.dispatcher // to use Akka's scheduler.schedule() + private val forwardMessageThread = + ThreadUtils.newDaemonSingleThreadScheduledExecutor("master-forward-message-thread") + + // TODO Remove it once we don't use akka.serialization.Serialization + private val actorSystem = rpcEnv.asInstanceOf[AkkaRpcEnv].actorSystem private val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf) - private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss") // For application IDs + private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss") // For application IDs - private val WORKER_TIMEOUT = conf.getLong("spark.worker.timeout", 60) * 1000 + private val WORKER_TIMEOUT_MS = conf.getLong("spark.worker.timeout", 60) * 1000 private val RETAINED_APPLICATIONS = conf.getInt("spark.deploy.retainedApplications", 200) private val RETAINED_DRIVERS = conf.getInt("spark.deploy.retainedDrivers", 200) private val REAPER_ITERATIONS = conf.getInt("spark.dead.worker.persistence", 15) @@ -75,10 +77,10 @@ private[master] class Master( val apps = new HashSet[ApplicationInfo] private val idToWorker = new HashMap[String, WorkerInfo] - private val addressToWorker = new HashMap[Address, WorkerInfo] + private val addressToWorker = new HashMap[RpcAddress, WorkerInfo] - private val actorToApp = new HashMap[ActorRef, ApplicationInfo] - private val addressToApp = new HashMap[Address, ApplicationInfo] + private val endpointToApp = new HashMap[RpcEndpointRef, ApplicationInfo] + private val addressToApp = new HashMap[RpcAddress, ApplicationInfo] private val completedApps = new ArrayBuffer[ApplicationInfo] private var nextAppNumber = 0 private val appIdToUI = new HashMap[String, SparkUI] @@ -89,21 +91,22 @@ private[master] class Master( private val waitingDrivers = new ArrayBuffer[DriverInfo] private var nextDriverNumber = 0 - Utils.checkHost(host, "Expected hostname") + Utils.checkHost(address.host, "Expected hostname") private val masterMetricsSystem = MetricsSystem.createMetricsSystem("master", conf, securityMgr) private val applicationMetricsSystem = MetricsSystem.createMetricsSystem("applications", conf, securityMgr) private val masterSource = new MasterSource(this) - private val webUi = 
new MasterWebUI(this, webUiPort) + // After onStart, webUi will be set + private var webUi: MasterWebUI = null private val masterPublicAddress = { val envVar = conf.getenv("SPARK_PUBLIC_DNS") - if (envVar != null) envVar else host + if (envVar != null) envVar else address.host } - private val masterUrl = "spark://" + host + ":" + port + private val masterUrl = address.toSparkURL private var masterWebUiUrl: String = _ private var state = RecoveryState.STANDBY @@ -112,7 +115,9 @@ private[master] class Master( private var leaderElectionAgent: LeaderElectionAgent = _ - private var recoveryCompletionTask: Cancellable = _ + private var recoveryCompletionTask: ScheduledFuture[_] = _ + + private var checkForWorkerTimeOutTask: ScheduledFuture[_] = _ // As a temporary workaround before better ways of configuring memory, we allow users to set // a flag that will perform round-robin scheduling across the nodes (spreading out each app @@ -130,20 +135,23 @@ private[master] class Master( private val restServer = if (restServerEnabled) { val port = conf.getInt("spark.master.rest.port", 6066) - Some(new StandaloneRestServer(host, port, conf, self, masterUrl)) + Some(new StandaloneRestServer(address.host, port, conf, self, masterUrl)) } else { None } private val restServerBoundPort = restServer.map(_.start()) - override def preStart() { + override def onStart(): Unit = { logInfo("Starting Spark master at " + masterUrl) logInfo(s"Running Spark version ${org.apache.spark.SPARK_VERSION}") - // Listen for remote client disconnection events, since they don't go through Akka's watch() - context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) + webUi = new MasterWebUI(this, webUiPort) webUi.bind() masterWebUiUrl = "http://" + masterPublicAddress + ":" + webUi.boundPort - context.system.scheduler.schedule(0 millis, WORKER_TIMEOUT millis, self, CheckForWorkerTimeOut) + checkForWorkerTimeOutTask = forwardMessageThread.scheduleAtFixedRate(new Runnable { + override def run(): Unit = Utils.tryLogNonFatalError { + self.send(CheckForWorkerTimeOut) + } + }, 0, WORKER_TIMEOUT_MS, TimeUnit.MILLISECONDS) masterMetricsSystem.registerSource(masterSource) masterMetricsSystem.start() @@ -157,16 +165,16 @@ private[master] class Master( case "ZOOKEEPER" => logInfo("Persisting recovery state to ZooKeeper") val zkFactory = - new ZooKeeperRecoveryModeFactory(conf, SerializationExtension(context.system)) + new ZooKeeperRecoveryModeFactory(conf, SerializationExtension(actorSystem)) (zkFactory.createPersistenceEngine(), zkFactory.createLeaderElectionAgent(this)) case "FILESYSTEM" => val fsFactory = - new FileSystemRecoveryModeFactory(conf, SerializationExtension(context.system)) + new FileSystemRecoveryModeFactory(conf, SerializationExtension(actorSystem)) (fsFactory.createPersistenceEngine(), fsFactory.createLeaderElectionAgent(this)) case "CUSTOM" => val clazz = Class.forName(conf.get("spark.deploy.recoveryMode.factory")) val factory = clazz.getConstructor(classOf[SparkConf], classOf[Serialization]) - .newInstance(conf, SerializationExtension(context.system)) + .newInstance(conf, SerializationExtension(actorSystem)) .asInstanceOf[StandaloneRecoveryModeFactory] (factory.createPersistenceEngine(), factory.createLeaderElectionAgent(this)) case _ => @@ -176,18 +184,17 @@ private[master] class Master( leaderElectionAgent = leaderElectionAgent_ } - override def preRestart(reason: Throwable, message: Option[Any]) { - super.preRestart(reason, message) // calls postStop()! 
- logError("Master actor restarted due to exception", reason) - } - - override def postStop() { + override def onStop() { masterMetricsSystem.report() applicationMetricsSystem.report() // prevent the CompleteRecovery message sending to restarted master if (recoveryCompletionTask != null) { - recoveryCompletionTask.cancel() + recoveryCompletionTask.cancel(true) } + if (checkForWorkerTimeOutTask != null) { + checkForWorkerTimeOutTask.cancel(true) + } + forwardMessageThread.shutdownNow() webUi.stop() restServer.foreach(_.stop()) masterMetricsSystem.stop() @@ -197,14 +204,14 @@ private[master] class Master( } override def electedLeader() { - self ! ElectedLeader + self.send(ElectedLeader) } override def revokedLeadership() { - self ! RevokedLeadership + self.send(RevokedLeadership) } - override def receiveWithLogging: PartialFunction[Any, Unit] = { + override def receive: PartialFunction[Any, Unit] = { case ElectedLeader => { val (storedApps, storedDrivers, storedWorkers) = persistenceEngine.readPersistedData() state = if (storedApps.isEmpty && storedDrivers.isEmpty && storedWorkers.isEmpty) { @@ -215,8 +222,11 @@ private[master] class Master( logInfo("I have been elected leader! New state: " + state) if (state == RecoveryState.RECOVERING) { beginRecovery(storedApps, storedDrivers, storedWorkers) - recoveryCompletionTask = context.system.scheduler.scheduleOnce(WORKER_TIMEOUT millis, self, - CompleteRecovery) + recoveryCompletionTask = forwardMessageThread.schedule(new Runnable { + override def run(): Unit = Utils.tryLogNonFatalError { + self.send(CompleteRecovery) + } + }, WORKER_TIMEOUT_MS, TimeUnit.MILLISECONDS) } } @@ -227,111 +237,42 @@ private[master] class Master( System.exit(0) } - case RegisterWorker(id, workerHost, workerPort, cores, memory, workerUiPort, publicAddress) => - { + case RegisterWorker( + id, workerHost, workerPort, workerRef, cores, memory, workerUiPort, publicAddress) => { logInfo("Registering worker %s:%d with %d cores, %s RAM".format( workerHost, workerPort, cores, Utils.megabytesToString(memory))) if (state == RecoveryState.STANDBY) { // ignore, don't send response } else if (idToWorker.contains(id)) { - sender ! RegisterWorkerFailed("Duplicate worker ID") + workerRef.send(RegisterWorkerFailed("Duplicate worker ID")) } else { val worker = new WorkerInfo(id, workerHost, workerPort, cores, memory, - sender, workerUiPort, publicAddress) + workerRef, workerUiPort, publicAddress) if (registerWorker(worker)) { persistenceEngine.addWorker(worker) - sender ! RegisteredWorker(masterUrl, masterWebUiUrl) + workerRef.send(RegisteredWorker(self, masterWebUiUrl)) schedule() } else { - val workerAddress = worker.actor.path.address + val workerAddress = worker.endpoint.address logWarning("Worker registration failed. Attempted to re-register worker at same " + "address: " + workerAddress) - sender ! RegisterWorkerFailed("Attempted to re-register worker at same address: " - + workerAddress) + workerRef.send(RegisterWorkerFailed("Attempted to re-register worker at same address: " + + workerAddress)) } } } - case RequestSubmitDriver(description) => { - if (state != RecoveryState.ALIVE) { - val msg = s"${Utils.BACKUP_STANDALONE_MASTER_PREFIX}: $state. " + - "Can only accept driver submissions in ALIVE state." - sender ! 
SubmitDriverResponse(false, None, msg) - } else { - logInfo("Driver submitted " + description.command.mainClass) - val driver = createDriver(description) - persistenceEngine.addDriver(driver) - waitingDrivers += driver - drivers.add(driver) - schedule() - - // TODO: It might be good to instead have the submission client poll the master to determine - // the current status of the driver. For now it's simply "fire and forget". - - sender ! SubmitDriverResponse(true, Some(driver.id), - s"Driver successfully submitted as ${driver.id}") - } - } - - case RequestKillDriver(driverId) => { - if (state != RecoveryState.ALIVE) { - val msg = s"${Utils.BACKUP_STANDALONE_MASTER_PREFIX}: $state. " + - s"Can only kill drivers in ALIVE state." - sender ! KillDriverResponse(driverId, success = false, msg) - } else { - logInfo("Asked to kill driver " + driverId) - val driver = drivers.find(_.id == driverId) - driver match { - case Some(d) => - if (waitingDrivers.contains(d)) { - waitingDrivers -= d - self ! DriverStateChanged(driverId, DriverState.KILLED, None) - } else { - // We just notify the worker to kill the driver here. The final bookkeeping occurs - // on the return path when the worker submits a state change back to the master - // to notify it that the driver was successfully killed. - d.worker.foreach { w => - w.actor ! KillDriver(driverId) - } - } - // TODO: It would be nice for this to be a synchronous response - val msg = s"Kill request for $driverId submitted" - logInfo(msg) - sender ! KillDriverResponse(driverId, success = true, msg) - case None => - val msg = s"Driver $driverId has already finished or does not exist" - logWarning(msg) - sender ! KillDriverResponse(driverId, success = false, msg) - } - } - } - - case RequestDriverStatus(driverId) => { - if (state != RecoveryState.ALIVE) { - val msg = s"${Utils.BACKUP_STANDALONE_MASTER_PREFIX}: $state. " + - "Can only request driver status in ALIVE state." - sender ! DriverStatusResponse(found = false, None, None, None, Some(new Exception(msg))) - } else { - (drivers ++ completedDrivers).find(_.id == driverId) match { - case Some(driver) => - sender ! DriverStatusResponse(found = true, Some(driver.state), - driver.worker.map(_.id), driver.worker.map(_.hostPort), driver.exception) - case None => - sender ! DriverStatusResponse(found = false, None, None, None, None) - } - } - } - - case RegisterApplication(description) => { + case RegisterApplication(description, driver) => { + // TODO Prevent repeated registrations from some driver if (state == RecoveryState.STANDBY) { // ignore, don't send response } else { logInfo("Registering app " + description.name) - val app = createApplication(description, sender) + val app = createApplication(description, driver) registerApplication(app) logInfo("Registered app " + description.name + " with ID " + app.id) persistenceEngine.addApplication(app) - sender ! RegisteredApplication(app.id, masterUrl) + driver.send(RegisteredApplication(app.id, self)) schedule() } } @@ -343,7 +284,7 @@ private[master] class Master( val appInfo = idToApp(appId) exec.state = state if (state == ExecutorState.RUNNING) { appInfo.resetRetryCount() } - exec.application.driver ! 
ExecutorUpdated(execId, state, message, exitStatus) + exec.application.driver.send(ExecutorUpdated(execId, state, message, exitStatus)) if (ExecutorState.isFinished(state)) { // Remove this executor from the worker and app logInfo(s"Removing executor ${exec.fullId} because it is $state") @@ -384,7 +325,7 @@ private[master] class Master( } } - case Heartbeat(workerId) => { + case Heartbeat(workerId, worker) => { idToWorker.get(workerId) match { case Some(workerInfo) => workerInfo.lastHeartbeat = System.currentTimeMillis() @@ -392,7 +333,7 @@ private[master] class Master( if (workers.map(_.id).contains(workerId)) { logWarning(s"Got heartbeat from unregistered worker $workerId." + " Asking it to re-register.") - sender ! ReconnectWorker(masterUrl) + worker.send(ReconnectWorker(masterUrl)) } else { logWarning(s"Got heartbeat from unregistered worker $workerId." + " This worker was never registered, so ignoring the heartbeat.") @@ -444,30 +385,103 @@ private[master] class Master( logInfo(s"Received unregister request from application $applicationId") idToApp.get(applicationId).foreach(finishApplication) - case DisassociatedEvent(_, address, _) => { - // The disconnected client could've been either a worker or an app; remove whichever it was - logInfo(s"$address got disassociated, removing it.") - addressToWorker.get(address).foreach(removeWorker) - addressToApp.get(address).foreach(finishApplication) - if (state == RecoveryState.RECOVERING && canCompleteRecovery) { completeRecovery() } + case CheckForWorkerTimeOut => { + timeOutDeadWorkers() } + } - case RequestMasterState => { - sender ! MasterStateResponse( - host, port, restServerBoundPort, - workers.toArray, apps.toArray, completedApps.toArray, - drivers.toArray, completedDrivers.toArray, state) + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case RequestSubmitDriver(description) => { + if (state != RecoveryState.ALIVE) { + val msg = s"${Utils.BACKUP_STANDALONE_MASTER_PREFIX}: $state. " + + "Can only accept driver submissions in ALIVE state." + context.reply(SubmitDriverResponse(self, false, None, msg)) + } else { + logInfo("Driver submitted " + description.command.mainClass) + val driver = createDriver(description) + persistenceEngine.addDriver(driver) + waitingDrivers += driver + drivers.add(driver) + schedule() + + // TODO: It might be good to instead have the submission client poll the master to determine + // the current status of the driver. For now it's simply "fire and forget". + + context.reply(SubmitDriverResponse(self, true, Some(driver.id), + s"Driver successfully submitted as ${driver.id}")) + } } - case CheckForWorkerTimeOut => { - timeOutDeadWorkers() + case RequestKillDriver(driverId) => { + if (state != RecoveryState.ALIVE) { + val msg = s"${Utils.BACKUP_STANDALONE_MASTER_PREFIX}: $state. " + + s"Can only kill drivers in ALIVE state." + context.reply(KillDriverResponse(self, driverId, success = false, msg)) + } else { + logInfo("Asked to kill driver " + driverId) + val driver = drivers.find(_.id == driverId) + driver match { + case Some(d) => + if (waitingDrivers.contains(d)) { + waitingDrivers -= d + self.send(DriverStateChanged(driverId, DriverState.KILLED, None)) + } else { + // We just notify the worker to kill the driver here. The final bookkeeping occurs + // on the return path when the worker submits a state change back to the master + // to notify it that the driver was successfully killed. 
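The handlers above illustrate the new split in the RPC API: one-way messages go through receive, while ask-style requests go through receiveAndReply and answer via context.reply instead of Akka's sender !. A minimal endpoint showing both paths, written against the org.apache.spark.rpc abstractions used throughout this patch (EchoEndpoint, Ping, Pong and Tick are illustrative):

    import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint}

    // Hypothetical messages, used only for illustration.
    case class Ping(id: Int)
    case class Pong(id: Int)
    case object Tick

    class EchoEndpoint(override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint {
      // One-way messages delivered with RpcEndpointRef.send land here; no reply is possible.
      override def receive: PartialFunction[Any, Unit] = {
        case Tick => ()  // side effects only, like Heartbeat or CheckForWorkerTimeOut
      }

      // Ask-style messages land here and are answered through context.reply,
      // which replaces the old Akka "sender ! ...".
      override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
        case Ping(id) => context.reply(Pong(id))
      }
    }

Given a ref returned by rpcEnv.setupEndpoint, ref.send(Tick) is fire-and-forget while ref.askWithRetry[Pong](Ping(1)) blocks for the reply, which is the same distinction the Master draws between Heartbeat and RequestSubmitDriver.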
+ d.worker.foreach { w => + w.endpoint.send(KillDriver(driverId)) + } + } + // TODO: It would be nice for this to be a synchronous response + val msg = s"Kill request for $driverId submitted" + logInfo(msg) + context.reply(KillDriverResponse(self, driverId, success = true, msg)) + case None => + val msg = s"Driver $driverId has already finished or does not exist" + logWarning(msg) + context.reply(KillDriverResponse(self, driverId, success = false, msg)) + } + } + } + + case RequestDriverStatus(driverId) => { + if (state != RecoveryState.ALIVE) { + val msg = s"${Utils.BACKUP_STANDALONE_MASTER_PREFIX}: $state. " + + "Can only request driver status in ALIVE state." + context.reply( + DriverStatusResponse(found = false, None, None, None, Some(new Exception(msg)))) + } else { + (drivers ++ completedDrivers).find(_.id == driverId) match { + case Some(driver) => + context.reply(DriverStatusResponse(found = true, Some(driver.state), + driver.worker.map(_.id), driver.worker.map(_.hostPort), driver.exception)) + case None => + context.reply(DriverStatusResponse(found = false, None, None, None, None)) + } + } + } + + case RequestMasterState => { + context.reply(MasterStateResponse( + address.host, address.port, restServerBoundPort, + workers.toArray, apps.toArray, completedApps.toArray, + drivers.toArray, completedDrivers.toArray, state)) } case BoundPortsRequest => { - sender ! BoundPortsResponse(port, webUi.boundPort, restServerBoundPort) + context.reply(BoundPortsResponse(address.port, webUi.boundPort, restServerBoundPort)) } } + override def onDisconnected(address: RpcAddress): Unit = { + // The disconnected client could've been either a worker or an app; remove whichever it was + logInfo(s"$address got disassociated, removing it.") + addressToWorker.get(address).foreach(removeWorker) + addressToApp.get(address).foreach(finishApplication) + if (state == RecoveryState.RECOVERING && canCompleteRecovery) { completeRecovery() } + } + private def canCompleteRecovery = workers.count(_.state == WorkerState.UNKNOWN) == 0 && apps.count(_.state == ApplicationState.UNKNOWN) == 0 @@ -479,7 +493,7 @@ private[master] class Master( try { registerApplication(app) app.state = ApplicationState.UNKNOWN - app.driver ! MasterChanged(masterUrl, masterWebUiUrl) + app.driver.send(MasterChanged(self, masterWebUiUrl)) } catch { case e: Exception => logInfo("App " + app.id + " had exception on reconnect") } @@ -496,7 +510,7 @@ private[master] class Master( try { registerWorker(worker) worker.state = WorkerState.UNKNOWN - worker.actor ! MasterChanged(masterUrl, masterWebUiUrl) + worker.endpoint.send(MasterChanged(self, masterWebUiUrl)) } catch { case e: Exception => logInfo("Worker " + worker.id + " had exception on reconnect") } @@ -504,6 +518,7 @@ private[master] class Master( } private def completeRecovery() { + // TODO Why synchronized // Ensure "only-once" recovery semantics using a short synchronization period. synchronized { if (state != RecoveryState.RECOVERING) { return } @@ -623,10 +638,10 @@ private[master] class Master( private def launchExecutor(worker: WorkerInfo, exec: ExecutorDesc): Unit = { logInfo("Launching executor " + exec.fullId + " on worker " + worker.id) worker.addExecutor(exec) - worker.actor ! LaunchExecutor(masterUrl, - exec.application.id, exec.id, exec.application.desc, exec.cores, exec.memory) - exec.application.driver ! 
ExecutorAdded( - exec.id, worker.id, worker.hostPort, exec.cores, exec.memory) + worker.endpoint.send(LaunchExecutor(masterUrl, + exec.application.id, exec.id, exec.application.desc, exec.cores, exec.memory)) + exec.application.driver.send(ExecutorAdded( + exec.id, worker.id, worker.hostPort, exec.cores, exec.memory)) } private def registerWorker(worker: WorkerInfo): Boolean = { @@ -638,7 +653,7 @@ private[master] class Master( workers -= w } - val workerAddress = worker.actor.path.address + val workerAddress = worker.endpoint.address if (addressToWorker.contains(workerAddress)) { val oldWorker = addressToWorker(workerAddress) if (oldWorker.state == WorkerState.UNKNOWN) { @@ -661,11 +676,11 @@ private[master] class Master( logInfo("Removing worker " + worker.id + " on " + worker.host + ":" + worker.port) worker.setState(WorkerState.DEAD) idToWorker -= worker.id - addressToWorker -= worker.actor.path.address + addressToWorker -= worker.endpoint.address for (exec <- worker.executors.values) { logInfo("Telling app of lost executor: " + exec.id) - exec.application.driver ! ExecutorUpdated( - exec.id, ExecutorState.LOST, Some("worker lost"), None) + exec.application.driver.send(ExecutorUpdated( + exec.id, ExecutorState.LOST, Some("worker lost"), None)) exec.application.removeExecutor(exec) } for (driver <- worker.drivers.values) { @@ -687,14 +702,15 @@ private[master] class Master( schedule() } - private def createApplication(desc: ApplicationDescription, driver: ActorRef): ApplicationInfo = { + private def createApplication(desc: ApplicationDescription, driver: RpcEndpointRef): + ApplicationInfo = { val now = System.currentTimeMillis() val date = new Date(now) new ApplicationInfo(now, newApplicationId(date), desc, date, driver, defaultCores) } private def registerApplication(app: ApplicationInfo): Unit = { - val appAddress = app.driver.path.address + val appAddress = app.driver.address if (addressToApp.contains(appAddress)) { logInfo("Attempted to re-register application at same address: " + appAddress) return @@ -703,7 +719,7 @@ private[master] class Master( applicationMetricsSystem.registerSource(app.appSource) apps += app idToApp(app.id) = app - actorToApp(app.driver) = app + endpointToApp(app.driver) = app addressToApp(appAddress) = app waitingApps += app } @@ -717,8 +733,8 @@ private[master] class Master( logInfo("Removing app " + app.id) apps -= app idToApp -= app.id - actorToApp -= app.driver - addressToApp -= app.driver.path.address + endpointToApp -= app.driver + addressToApp -= app.driver.address if (completedApps.size >= RETAINED_APPLICATIONS) { val toRemove = math.max(RETAINED_APPLICATIONS / 10, 1) completedApps.take(toRemove).foreach( a => { @@ -735,19 +751,19 @@ private[master] class Master( for (exec <- app.executors.values) { exec.worker.removeExecutor(exec) - exec.worker.actor ! KillExecutor(masterUrl, exec.application.id, exec.id) + exec.worker.endpoint.send(KillExecutor(masterUrl, exec.application.id, exec.id)) exec.state = ExecutorState.KILLED } app.markFinished(state) if (state != ApplicationState.FINISHED) { - app.driver ! ApplicationRemoved(state.toString) + app.driver.send(ApplicationRemoved(state.toString)) } persistenceEngine.removeApplication(app) schedule() // Tell all workers that the application has finished, so they can clean up any app state. workers.foreach { w => - w.actor ! 
ApplicationFinished(app.id) + w.endpoint.send(ApplicationFinished(app.id)) } } } @@ -768,7 +784,7 @@ private[master] class Master( } val eventLogFilePrefix = EventLoggingListener.getLogPath( - eventLogDir, app.id, None, app.desc.eventLogCodec) + eventLogDir, app.id, app.desc.eventLogCodec) val fs = Utils.getHadoopFileSystem(eventLogDir, hadoopConf) val inProgressExists = fs.exists(new Path(eventLogFilePrefix + EventLoggingListener.IN_PROGRESS)) @@ -832,14 +848,14 @@ private[master] class Master( private def timeOutDeadWorkers() { // Copy the workers into an array so we don't modify the hashset while iterating through it val currentTime = System.currentTimeMillis() - val toRemove = workers.filter(_.lastHeartbeat < currentTime - WORKER_TIMEOUT).toArray + val toRemove = workers.filter(_.lastHeartbeat < currentTime - WORKER_TIMEOUT_MS).toArray for (worker <- toRemove) { if (worker.state != WorkerState.DEAD) { logWarning("Removing %s because we got no heartbeat in %d seconds".format( - worker.id, WORKER_TIMEOUT/1000)) + worker.id, WORKER_TIMEOUT_MS / 1000)) removeWorker(worker) } else { - if (worker.lastHeartbeat < currentTime - ((REAPER_ITERATIONS + 1) * WORKER_TIMEOUT)) { + if (worker.lastHeartbeat < currentTime - ((REAPER_ITERATIONS + 1) * WORKER_TIMEOUT_MS)) { workers -= worker // we've seen this DEAD worker in the UI, etc. for long enough; cull it } } @@ -862,7 +878,7 @@ private[master] class Master( logInfo("Launching driver " + driver.id + " on worker " + worker.id) worker.addDriver(driver) driver.worker = Some(worker) - worker.actor ! LaunchDriver(driver.id, driver.desc) + worker.endpoint.send(LaunchDriver(driver.id, driver.desc)) driver.state = DriverState.RUNNING } @@ -891,57 +907,33 @@ private[master] class Master( } private[deploy] object Master extends Logging { - val systemName = "sparkMaster" - private val actorName = "Master" + val SYSTEM_NAME = "sparkMaster" + val ENDPOINT_NAME = "Master" def main(argStrings: Array[String]) { SignalLogger.register(log) val conf = new SparkConf val args = new MasterArguments(argStrings, conf) - val (actorSystem, _, _, _) = startSystemAndActor(args.host, args.port, args.webUiPort, conf) - actorSystem.awaitTermination() - } - - /** - * Returns an `akka.tcp://...` URL for the Master actor given a sparkUrl `spark://host:port`. - * - * @throws SparkException if the url is invalid - */ - def toAkkaUrl(sparkUrl: String, protocol: String): String = { - val (host, port) = Utils.extractHostPortFromSparkUrl(sparkUrl) - AkkaUtils.address(protocol, systemName, host, port, actorName) - } - - /** - * Returns an akka `Address` for the Master actor given a sparkUrl `spark://host:port`. 
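With toAkkaUrl and toAkkaAddress removed, spark://host:port strings are handled by RpcAddress instead, as the Worker changes further down show with RpcAddress.fromSparkURL. A short sketch of the round trip (the SparkUrlExample object is illustrative):

    import org.apache.spark.rpc.RpcAddress

    object SparkUrlExample {
      val addr = RpcAddress.fromSparkURL("spark://1.2.3.4:1234")  // RpcAddress("1.2.3.4", 1234)
      val url: String = addr.toSparkURL                           // "spark://1.2.3.4:1234"
    }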
- * - * @throws SparkException if the url is invalid - */ - def toAkkaAddress(sparkUrl: String, protocol: String): Address = { - val (host, port) = Utils.extractHostPortFromSparkUrl(sparkUrl) - Address(protocol, systemName, host, port) + val (rpcEnv, _, _) = startRpcEnvAndEndpoint(args.host, args.port, args.webUiPort, conf) + rpcEnv.awaitTermination() } /** - * Start the Master and return a four tuple of: - * (1) The Master actor system - * (2) The bound port - * (3) The web UI bound port - * (4) The REST server bound port, if any + * Start the Master and return a three tuple of: + * (1) The Master RpcEnv + * (2) The web UI bound port + * (3) The REST server bound port, if any */ - def startSystemAndActor( + def startRpcEnvAndEndpoint( host: String, port: Int, webUiPort: Int, - conf: SparkConf): (ActorSystem, Int, Int, Option[Int]) = { + conf: SparkConf): (RpcEnv, Int, Option[Int]) = { val securityMgr = new SecurityManager(conf) - val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port, conf = conf, - securityManager = securityMgr) - val actor = actorSystem.actorOf( - Props(classOf[Master], host, boundPort, webUiPort, securityMgr, conf), actorName) - val timeout = RpcUtils.askTimeout(conf) - val portsRequest = actor.ask(BoundPortsRequest)(timeout) - val portsResponse = Await.result(portsRequest, timeout).asInstanceOf[BoundPortsResponse] - (actorSystem, boundPort, portsResponse.webUIPort, portsResponse.restPort) + val rpcEnv = RpcEnv.create(SYSTEM_NAME, host, port, conf, securityMgr) + val masterEndpoint = rpcEnv.setupEndpoint(ENDPOINT_NAME, + new Master(rpcEnv, rpcEnv.address, webUiPort, securityMgr, conf)) + val portsResponse = masterEndpoint.askWithRetry[BoundPortsResponse](BoundPortsRequest) + (rpcEnv, portsResponse.webUIPort, portsResponse.restPort) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala b/core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala index 15c6296888f70..68c937188b333 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala @@ -28,7 +28,7 @@ private[master] object MasterMessages { case object RevokedLeadership - // Actor System to Master + // Master to itself case object CheckForWorkerTimeOut diff --git a/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala index 9b3d48c6edc84..471811037e5e2 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala @@ -19,9 +19,7 @@ package org.apache.spark.deploy.master import scala.collection.mutable -import akka.actor.ActorRef - -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.util.Utils private[spark] class WorkerInfo( @@ -30,7 +28,7 @@ private[spark] class WorkerInfo( val port: Int, val cores: Int, val memory: Int, - val actor: ActorRef, + val endpoint: RpcEndpointRef, val webUiPort: Int, val publicAddress: String) extends Serializable { diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala index 52758d6a7c4be..6fdff86f66e01 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala +++ 
b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala @@ -17,10 +17,7 @@ package org.apache.spark.deploy.master -import akka.actor.ActorRef - import org.apache.spark.{Logging, SparkConf} -import org.apache.spark.deploy.master.MasterMessages._ import org.apache.curator.framework.CuratorFramework import org.apache.curator.framework.recipes.leader.{LeaderLatchListener, LeaderLatch} import org.apache.spark.deploy.SparkCuratorUtil diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala index 06e265f99e231..e28e7e379ac91 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala @@ -19,11 +19,8 @@ package org.apache.spark.deploy.master.ui import javax.servlet.http.HttpServletRequest -import scala.concurrent.Await import scala.xml.Node -import akka.pattern.ask - import org.apache.spark.deploy.ExecutorState import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, RequestMasterState} import org.apache.spark.deploy.master.ExecutorDesc @@ -32,14 +29,12 @@ import org.apache.spark.util.Utils private[ui] class ApplicationPage(parent: MasterWebUI) extends WebUIPage("app") { - private val master = parent.masterActorRef - private val timeout = parent.timeout + private val master = parent.masterEndpointRef /** Executor details for a particular application */ def render(request: HttpServletRequest): Seq[Node] = { val appId = request.getParameter("appId") - val stateFuture = (master ? RequestMasterState)(timeout).mapTo[MasterStateResponse] - val state = Await.result(stateFuture, timeout) + val state = master.askWithRetry[MasterStateResponse](RequestMasterState) val app = state.activeApps.find(_.id == appId).getOrElse({ state.completedApps.find(_.id == appId).getOrElse(null) }) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala index 6a7c74020bace..c3e20ebf8d6eb 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala @@ -19,25 +19,21 @@ package org.apache.spark.deploy.master.ui import javax.servlet.http.HttpServletRequest -import scala.concurrent.Await import scala.xml.Node -import akka.pattern.ask import org.json4s.JValue import org.apache.spark.deploy.JsonProtocol -import org.apache.spark.deploy.DeployMessages.{RequestKillDriver, MasterStateResponse, RequestMasterState} +import org.apache.spark.deploy.DeployMessages.{KillDriverResponse, RequestKillDriver, MasterStateResponse, RequestMasterState} import org.apache.spark.deploy.master._ import org.apache.spark.ui.{WebUIPage, UIUtils} import org.apache.spark.util.Utils private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { - private val master = parent.masterActorRef - private val timeout = parent.timeout + private val master = parent.masterEndpointRef def getMasterState: MasterStateResponse = { - val stateFuture = (master ? 
RequestMasterState)(timeout).mapTo[MasterStateResponse] - Await.result(stateFuture, timeout) + master.askWithRetry[MasterStateResponse](RequestMasterState) } override def renderJson(request: HttpServletRequest): JValue = { @@ -53,7 +49,9 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { } def handleDriverKillRequest(request: HttpServletRequest): Unit = { - handleKillRequest(request, id => { master ! RequestKillDriver(id) }) + handleKillRequest(request, id => { + master.ask[KillDriverResponse](RequestKillDriver(id)) + }) } private def handleKillRequest(request: HttpServletRequest, action: String => Unit): Unit = { diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala index 2111a8581f2e4..6174fc11f83d8 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala @@ -23,7 +23,6 @@ import org.apache.spark.status.api.v1.{ApiRootResource, ApplicationsListResource UIRoot} import org.apache.spark.ui.{SparkUI, WebUI} import org.apache.spark.ui.JettyUtils._ -import org.apache.spark.util.RpcUtils /** * Web UI server for the standalone master. @@ -33,8 +32,7 @@ class MasterWebUI(val master: Master, requestedPort: Int) extends WebUI(master.securityMgr, requestedPort, master.conf, name = "MasterUI") with Logging with UIRoot { - val masterActorRef = master.self - val timeout = RpcUtils.askTimeout(master.conf) + val masterEndpointRef = master.self val killEnabled = master.conf.getBoolean("spark.ui.killEnabled", true) val masterPage = new MasterPage(this) diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala index 502b9bb701ccf..d5b9bcab1423f 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala @@ -20,10 +20,10 @@ package org.apache.spark.deploy.rest import java.io.File import javax.servlet.http.HttpServletResponse -import akka.actor.ActorRef import org.apache.spark.deploy.ClientArguments._ import org.apache.spark.deploy.{Command, DeployMessages, DriverDescription} -import org.apache.spark.util.{AkkaUtils, RpcUtils, Utils} +import org.apache.spark.rpc.RpcEndpointRef +import org.apache.spark.util.Utils import org.apache.spark.{SPARK_VERSION => sparkVersion, SparkConf} /** @@ -45,35 +45,34 @@ import org.apache.spark.{SPARK_VERSION => sparkVersion, SparkConf} * @param host the address this server should bind to * @param requestedPort the port this server will attempt to bind to * @param masterConf the conf used by the Master - * @param masterActor reference to the Master actor to which requests can be sent + * @param masterEndpoint reference to the Master endpoint to which requests can be sent * @param masterUrl the URL of the Master new drivers will attempt to connect to */ private[deploy] class StandaloneRestServer( host: String, requestedPort: Int, masterConf: SparkConf, - masterActor: ActorRef, + masterEndpoint: RpcEndpointRef, masterUrl: String) extends RestSubmissionServer(host, requestedPort, masterConf) { protected override val submitRequestServlet = - new StandaloneSubmitRequestServlet(masterActor, masterUrl, masterConf) + new StandaloneSubmitRequestServlet(masterEndpoint, masterUrl, masterConf) protected override val killRequestServlet = - new 
StandaloneKillRequestServlet(masterActor, masterConf) + new StandaloneKillRequestServlet(masterEndpoint, masterConf) protected override val statusRequestServlet = - new StandaloneStatusRequestServlet(masterActor, masterConf) + new StandaloneStatusRequestServlet(masterEndpoint, masterConf) } /** * A servlet for handling kill requests passed to the [[StandaloneRestServer]]. */ -private[rest] class StandaloneKillRequestServlet(masterActor: ActorRef, conf: SparkConf) +private[rest] class StandaloneKillRequestServlet(masterEndpoint: RpcEndpointRef, conf: SparkConf) extends KillRequestServlet { protected def handleKill(submissionId: String): KillSubmissionResponse = { - val askTimeout = RpcUtils.askTimeout(conf) - val response = AkkaUtils.askWithReply[DeployMessages.KillDriverResponse]( - DeployMessages.RequestKillDriver(submissionId), masterActor, askTimeout) + val response = masterEndpoint.askWithRetry[DeployMessages.KillDriverResponse]( + DeployMessages.RequestKillDriver(submissionId)) val k = new KillSubmissionResponse k.serverSparkVersion = sparkVersion k.message = response.message @@ -86,13 +85,12 @@ private[rest] class StandaloneKillRequestServlet(masterActor: ActorRef, conf: Sp /** * A servlet for handling status requests passed to the [[StandaloneRestServer]]. */ -private[rest] class StandaloneStatusRequestServlet(masterActor: ActorRef, conf: SparkConf) +private[rest] class StandaloneStatusRequestServlet(masterEndpoint: RpcEndpointRef, conf: SparkConf) extends StatusRequestServlet { protected def handleStatus(submissionId: String): SubmissionStatusResponse = { - val askTimeout = RpcUtils.askTimeout(conf) - val response = AkkaUtils.askWithReply[DeployMessages.DriverStatusResponse]( - DeployMessages.RequestDriverStatus(submissionId), masterActor, askTimeout) + val response = masterEndpoint.askWithRetry[DeployMessages.DriverStatusResponse]( + DeployMessages.RequestDriverStatus(submissionId)) val message = response.exception.map { s"Exception from the cluster:\n" + formatException(_) } val d = new SubmissionStatusResponse d.serverSparkVersion = sparkVersion @@ -110,7 +108,7 @@ private[rest] class StandaloneStatusRequestServlet(masterActor: ActorRef, conf: * A servlet for handling submit requests passed to the [[StandaloneRestServer]]. 
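The servlet changes in this file all follow the same shape: the explicit askTimeout plus AkkaUtils.askWithReply call collapses into a single askWithRetry on the endpoint ref, which owns its own timeout and retry policy. A small sketch of that call pattern (StatusLookup and driverStatus are illustrative names; the messages are the ones used in this file):

    import org.apache.spark.deploy.DeployMessages.{DriverStatusResponse, RequestDriverStatus}
    import org.apache.spark.rpc.RpcEndpointRef

    object StatusLookup {
      // Before: AkkaUtils.askWithReply[DriverStatusResponse](msg, masterActor, askTimeout)
      // After:  a single blocking ask on the endpoint ref.
      def driverStatus(master: RpcEndpointRef, submissionId: String): DriverStatusResponse =
        master.askWithRetry[DriverStatusResponse](RequestDriverStatus(submissionId))
    }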
*/ private[rest] class StandaloneSubmitRequestServlet( - masterActor: ActorRef, + masterEndpoint: RpcEndpointRef, masterUrl: String, conf: SparkConf) extends SubmitRequestServlet { @@ -175,10 +173,9 @@ private[rest] class StandaloneSubmitRequestServlet( responseServlet: HttpServletResponse): SubmitRestProtocolResponse = { requestMessage match { case submitRequest: CreateSubmissionRequest => - val askTimeout = RpcUtils.askTimeout(conf) val driverDescription = buildDriverDescription(submitRequest) - val response = AkkaUtils.askWithReply[DeployMessages.SubmitDriverResponse]( - DeployMessages.RequestSubmitDriver(driverDescription), masterActor, askTimeout) + val response = masterEndpoint.askWithRetry[DeployMessages.SubmitDriverResponse]( + DeployMessages.RequestSubmitDriver(driverDescription)) val submitResponse = new CreateSubmissionResponse submitResponse.serverSparkVersion = sparkVersion submitResponse.message = response.message diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala index 1386055eb8c48..ec51c3d935d8e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala @@ -21,7 +21,6 @@ import java.io._ import scala.collection.JavaConversions._ -import akka.actor.ActorRef import com.google.common.base.Charsets.UTF_8 import com.google.common.io.Files import org.apache.hadoop.fs.Path @@ -31,6 +30,7 @@ import org.apache.spark.deploy.{DriverDescription, SparkHadoopUtil} import org.apache.spark.deploy.DeployMessages.DriverStateChanged import org.apache.spark.deploy.master.DriverState import org.apache.spark.deploy.master.DriverState.DriverState +import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.util.{Utils, Clock, SystemClock} /** @@ -43,7 +43,7 @@ private[deploy] class DriverRunner( val workDir: File, val sparkHome: File, val driverDesc: DriverDescription, - val worker: ActorRef, + val worker: RpcEndpointRef, val workerUrl: String, val securityManager: SecurityManager) extends Logging { @@ -107,7 +107,7 @@ private[deploy] class DriverRunner( finalState = Some(state) - worker ! DriverStateChanged(driverId, state, finalException) + worker.send(DriverStateChanged(driverId, state, finalException)) } }.start() } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala index fff17e1095042..29a5042285578 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala @@ -21,10 +21,10 @@ import java.io._ import scala.collection.JavaConversions._ -import akka.actor.ActorRef import com.google.common.base.Charsets.UTF_8 import com.google.common.io.Files +import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.{SecurityManager, SparkConf, Logging} import org.apache.spark.deploy.{ApplicationDescription, ExecutorState} import org.apache.spark.deploy.DeployMessages.ExecutorStateChanged @@ -41,7 +41,7 @@ private[deploy] class ExecutorRunner( val appDesc: ApplicationDescription, val cores: Int, val memory: Int, - val worker: ActorRef, + val worker: RpcEndpointRef, val workerId: String, val host: String, val webUiPort: Int, @@ -91,7 +91,7 @@ private[deploy] class ExecutorRunner( process.destroy() exitCode = Some(process.waitFor()) } - worker ! 
ExecutorStateChanged(appId, execId, state, message, exitCode) + worker.send(ExecutorStateChanged(appId, execId, state, message, exitCode)) } /** Stop this executor runner, including killing the process it launched */ @@ -159,7 +159,7 @@ private[deploy] class ExecutorRunner( val exitCode = process.waitFor() state = ExecutorState.EXITED val message = "Command exited with code " + exitCode - worker ! ExecutorStateChanged(appId, execId, state, Some(message), Some(exitCode)) + worker.send(ExecutorStateChanged(appId, execId, state, Some(message), Some(exitCode))) } catch { case interrupted: InterruptedException => { logInfo("Runner thread for executor " + fullId + " interrupted") diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index ebc6cd76c6afd..82e9578bbcba5 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -21,15 +21,14 @@ import java.io.File import java.io.IOException import java.text.SimpleDateFormat import java.util.{UUID, Date} +import java.util.concurrent._ +import java.util.concurrent.{Future => JFuture, ScheduledFuture => JScheduledFuture} import scala.collection.JavaConversions._ import scala.collection.mutable.{HashMap, HashSet} -import scala.concurrent.duration._ -import scala.language.postfixOps +import scala.concurrent.ExecutionContext import scala.util.Random - -import akka.actor._ -import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent} +import scala.util.control.NonFatal import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.deploy.{Command, ExecutorDescription, ExecutorState} @@ -38,32 +37,39 @@ import org.apache.spark.deploy.ExternalShuffleService import org.apache.spark.deploy.master.{DriverState, Master} import org.apache.spark.deploy.worker.ui.WorkerWebUI import org.apache.spark.metrics.MetricsSystem -import org.apache.spark.util.{ActorLogReceive, AkkaUtils, SignalLogger, Utils} +import org.apache.spark.rpc._ +import org.apache.spark.util.{ThreadUtils, SignalLogger, Utils} -/** - * @param masterAkkaUrls Each url should be a valid akka url. - */ private[worker] class Worker( - host: String, - port: Int, + override val rpcEnv: RpcEnv, webUiPort: Int, cores: Int, memory: Int, - masterAkkaUrls: Array[String], - actorSystemName: String, - actorName: String, + masterRpcAddresses: Array[RpcAddress], + systemName: String, + endpointName: String, workDirPath: String = null, val conf: SparkConf, val securityMgr: SecurityManager) - extends Actor with ActorLogReceive with Logging { - import context.dispatcher + extends ThreadSafeRpcEndpoint with Logging { + + private val host = rpcEnv.address.host + private val port = rpcEnv.address.port Utils.checkHost(host, "Expected hostname") assert (port > 0) + // A scheduled executor used to send messages at the specified time. + private val forwordMessageScheduler = + ThreadUtils.newDaemonSingleThreadScheduledExecutor("worker-forward-message-scheduler") + + // A separated thread to clean up the workDir. Used to provide the implicit parameter of `Future` + // methods. 
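The cleanup executor introduced here exists so that directory scans run on their own single thread rather than on the RPC dispatch thread. A minimal sketch of the same Future-with-explicit-ExecutionContext pattern, using only the standard library (CleanupSketch and cleanupAsync are illustrative names):

    import java.util.concurrent.Executors
    import scala.concurrent.{ExecutionContext, Future}

    object CleanupSketch {
      // A single dedicated thread, so long-running file scans never block message handling.
      private val cleanupEc =
        ExecutionContext.fromExecutorService(Executors.newSingleThreadExecutor())

      def cleanupAsync(doCleanup: () => Unit): Future[Unit] = {
        // The execution context is passed explicitly, as the Worker does with cleanupThreadExecutor.
        val f = Future { doCleanup() }(cleanupEc)
        f.onFailure {
          case e: Throwable => System.err.println("App dir cleanup failed: " + e.getMessage)
        }(cleanupEc)
        f
      }
    }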
+ private val cleanupThreadExecutor = ExecutionContext.fromExecutorService( + ThreadUtils.newDaemonSingleThreadExecutor("worker-cleanup-thread")) + // For worker and executor IDs private def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss") - // Send a heartbeat every (heartbeat timeout) / 4 milliseconds private val HEARTBEAT_MILLIS = conf.getLong("spark.worker.timeout", 60) * 1000 / 4 @@ -79,32 +85,26 @@ private[worker] class Worker( val randomNumberGenerator = new Random(UUID.randomUUID.getMostSignificantBits) randomNumberGenerator.nextDouble + FUZZ_MULTIPLIER_INTERVAL_LOWER_BOUND } - private val INITIAL_REGISTRATION_RETRY_INTERVAL = (math.round(10 * - REGISTRATION_RETRY_FUZZ_MULTIPLIER)).seconds - private val PROLONGED_REGISTRATION_RETRY_INTERVAL = (math.round(60 - * REGISTRATION_RETRY_FUZZ_MULTIPLIER)).seconds + private val INITIAL_REGISTRATION_RETRY_INTERVAL_SECONDS = (math.round(10 * + REGISTRATION_RETRY_FUZZ_MULTIPLIER)) + private val PROLONGED_REGISTRATION_RETRY_INTERVAL_SECONDS = (math.round(60 + * REGISTRATION_RETRY_FUZZ_MULTIPLIER)) private val CLEANUP_ENABLED = conf.getBoolean("spark.worker.cleanup.enabled", false) // How often worker will clean up old app folders private val CLEANUP_INTERVAL_MILLIS = conf.getLong("spark.worker.cleanup.interval", 60 * 30) * 1000 // TTL for app folders/data; after TTL expires it will be cleaned up - private val APP_DATA_RETENTION_SECS = + private val APP_DATA_RETENTION_SECONDS = conf.getLong("spark.worker.cleanup.appDataTtl", 7 * 24 * 3600) private val testing: Boolean = sys.props.contains("spark.testing") - private var master: ActorSelection = null - private var masterAddress: Address = null + private var master: Option[RpcEndpointRef] = None private var activeMasterUrl: String = "" private[worker] var activeMasterWebUiUrl : String = "" - private val akkaUrl = AkkaUtils.address( - AkkaUtils.protocol(context.system), - actorSystemName, - host, - port, - actorName) - @volatile private var registered = false - @volatile private var connected = false + private val workerUri = rpcEnv.uriOf(systemName, rpcEnv.address, endpointName) + private var registered = false + private var connected = false private val workerId = generateWorkerId() private val sparkHome = if (testing) { @@ -136,7 +136,18 @@ private[worker] class Worker( private val metricsSystem = MetricsSystem.createMetricsSystem("worker", conf, securityMgr) private val workerSource = new WorkerSource(this) - private var registrationRetryTimer: Option[Cancellable] = None + private var registerMasterFutures: Array[JFuture[_]] = null + private var registrationRetryTimer: Option[JScheduledFuture[_]] = None + + // A thread pool for registering with masters. Because registering with a master is a blocking + // action, this thread pool must be able to create "masterRpcAddresses.size" threads at the same + // time so that we can register with all masters. 
+ private val registerMasterThreadPool = new ThreadPoolExecutor( + 0, + masterRpcAddresses.size, // Make sure we can register with all masters at the same time + 60L, TimeUnit.SECONDS, + new SynchronousQueue[Runnable](), + ThreadUtils.namedThreadFactory("worker-register-master-threadpool")) var coresUsed = 0 var memoryUsed = 0 @@ -162,14 +173,13 @@ private[worker] class Worker( } } - override def preStart() { + override def onStart() { assert(!registered) logInfo("Starting Spark worker %s:%d with %d cores, %s RAM".format( host, port, cores, Utils.megabytesToString(memory))) logInfo(s"Running Spark version ${org.apache.spark.SPARK_VERSION}") logInfo("Spark home: " + sparkHome) createWorkDir() - context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) shuffleService.startIfEnabled() webUi = new WorkerWebUI(this, workDir, webUiPort) webUi.bind() @@ -181,24 +191,32 @@ private[worker] class Worker( metricsSystem.getServletHandlers.foreach(webUi.attachHandler) } - private def changeMaster(url: String, uiUrl: String) { + private def changeMaster(masterRef: RpcEndpointRef, uiUrl: String) { // activeMasterUrl it's a valid Spark url since we receive it from master. - activeMasterUrl = url + activeMasterUrl = masterRef.address.toSparkURL activeMasterWebUiUrl = uiUrl - master = context.actorSelection( - Master.toAkkaUrl(activeMasterUrl, AkkaUtils.protocol(context.system))) - masterAddress = Master.toAkkaAddress(activeMasterUrl, AkkaUtils.protocol(context.system)) + master = Some(masterRef) connected = true // Cancel any outstanding re-registration attempts because we found a new master - registrationRetryTimer.foreach(_.cancel()) - registrationRetryTimer = None + cancelLastRegistrationRetry() } - private def tryRegisterAllMasters() { - for (masterAkkaUrl <- masterAkkaUrls) { - logInfo("Connecting to master " + masterAkkaUrl + "...") - val actor = context.actorSelection(masterAkkaUrl) - actor ! RegisterWorker(workerId, host, port, cores, memory, webUi.boundPort, publicAddress) + private def tryRegisterAllMasters(): Array[JFuture[_]] = { + masterRpcAddresses.map { masterAddress => + registerMasterThreadPool.submit(new Runnable { + override def run(): Unit = { + try { + logInfo("Connecting to master " + masterAddress + "...") + val masterEndpoint = + rpcEnv.setupEndpointRef(Master.SYSTEM_NAME, masterAddress, Master.ENDPOINT_NAME) + masterEndpoint.send(RegisterWorker( + workerId, host, port, self, cores, memory, webUi.boundPort, publicAddress)) + } catch { + case ie: InterruptedException => // Cancelled + case NonFatal(e) => logWarning(s"Failed to connect to master $masterAddress", e) + } + } + }) } } @@ -211,8 +229,7 @@ private[worker] class Worker( Utils.tryOrExit { connectionAttemptCount += 1 if (registered) { - registrationRetryTimer.foreach(_.cancel()) - registrationRetryTimer = None + cancelLastRegistrationRetry() } else if (connectionAttemptCount <= TOTAL_REGISTRATION_RETRIES) { logInfo(s"Retrying connection to master (attempt # $connectionAttemptCount)") /** @@ -235,21 +252,48 @@ private[worker] class Worker( * still not safe if the old master recovers within this interval, but this is a much * less likely scenario. */ - if (master != null) { - master ! 
RegisterWorker( - workerId, host, port, cores, memory, webUi.boundPort, publicAddress) - } else { - // We are retrying the initial registration - tryRegisterAllMasters() + master match { + case Some(masterRef) => + // registered == false && master != None means we lost the connection to master, so + // masterRef cannot be used and we need to recreate it again. Note: we must not set + // master to None due to the above comments. + if (registerMasterFutures != null) { + registerMasterFutures.foreach(_.cancel(true)) + } + val masterAddress = masterRef.address + registerMasterFutures = Array(registerMasterThreadPool.submit(new Runnable { + override def run(): Unit = { + try { + logInfo("Connecting to master " + masterAddress + "...") + val masterEndpoint = + rpcEnv.setupEndpointRef(Master.SYSTEM_NAME, masterAddress, Master.ENDPOINT_NAME) + masterEndpoint.send(RegisterWorker( + workerId, host, port, self, cores, memory, webUi.boundPort, publicAddress)) + } catch { + case ie: InterruptedException => // Cancelled + case NonFatal(e) => logWarning(s"Failed to connect to master $masterAddress", e) + } + } + })) + case None => + if (registerMasterFutures != null) { + registerMasterFutures.foreach(_.cancel(true)) + } + // We are retrying the initial registration + registerMasterFutures = tryRegisterAllMasters() } // We have exceeded the initial registration retry threshold // All retries from now on should use a higher interval if (connectionAttemptCount == INITIAL_REGISTRATION_RETRIES) { - registrationRetryTimer.foreach(_.cancel()) - registrationRetryTimer = Some { - context.system.scheduler.schedule(PROLONGED_REGISTRATION_RETRY_INTERVAL, - PROLONGED_REGISTRATION_RETRY_INTERVAL, self, ReregisterWithMaster) - } + registrationRetryTimer.foreach(_.cancel(true)) + registrationRetryTimer = Some( + forwordMessageScheduler.scheduleAtFixedRate(new Runnable { + override def run(): Unit = Utils.tryLogNonFatalError { + self.send(ReregisterWithMaster) + } + }, PROLONGED_REGISTRATION_RETRY_INTERVAL_SECONDS, + PROLONGED_REGISTRATION_RETRY_INTERVAL_SECONDS, + TimeUnit.SECONDS)) } } else { logError("All masters are unresponsive! Giving up.") @@ -258,41 +302,67 @@ private[worker] class Worker( } } + /** + * Cancel last registeration retry, or do nothing if no retry + */ + private def cancelLastRegistrationRetry(): Unit = { + if (registerMasterFutures != null) { + registerMasterFutures.foreach(_.cancel(true)) + registerMasterFutures = null + } + registrationRetryTimer.foreach(_.cancel(true)) + registrationRetryTimer = None + } + private def registerWithMaster() { - // DisassociatedEvent may be triggered multiple times, so don't attempt registration + // onDisconnected may be triggered multiple times, so don't attempt registration // if there are outstanding registration attempts scheduled. 
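Registration retries in the Worker now hinge on an Option[ScheduledFuture[_]]: a retry loop is scheduled only when none is outstanding, and cancelLastRegistrationRetry() both cancels and clears it. A condensed sketch of that guard, assuming plain java.util.concurrent (RetryGuard and its members are illustrative names):

    import java.util.concurrent.{Executors, ScheduledFuture, TimeUnit}

    class RetryGuard(attempt: () => Unit) {
      private val scheduler = Executors.newSingleThreadScheduledExecutor()
      private var retryTimer: Option[ScheduledFuture[_]] = None

      // Schedule retries only if none are outstanding, as registerWithMaster() does below.
      def start(intervalSeconds: Long): Unit = retryTimer match {
        case None =>
          retryTimer = Some(scheduler.scheduleAtFixedRate(new Runnable {
            override def run(): Unit = attempt()
          }, intervalSeconds, intervalSeconds, TimeUnit.SECONDS))
        case Some(_) => ()  // an attempt is already scheduled; nothing to do
      }

      // Mirrors cancelLastRegistrationRetry(): cancel and clear so a new round can start later.
      def cancel(): Unit = {
        retryTimer.foreach(_.cancel(true))
        retryTimer = None
      }
    }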
registrationRetryTimer match { case None => registered = false - tryRegisterAllMasters() + registerMasterFutures = tryRegisterAllMasters() connectionAttemptCount = 0 - registrationRetryTimer = Some { - context.system.scheduler.schedule(INITIAL_REGISTRATION_RETRY_INTERVAL, - INITIAL_REGISTRATION_RETRY_INTERVAL, self, ReregisterWithMaster) - } + registrationRetryTimer = Some(forwordMessageScheduler.scheduleAtFixedRate( + new Runnable { + override def run(): Unit = Utils.tryLogNonFatalError { + self.send(ReregisterWithMaster) + } + }, + INITIAL_REGISTRATION_RETRY_INTERVAL_SECONDS, + INITIAL_REGISTRATION_RETRY_INTERVAL_SECONDS, + TimeUnit.SECONDS)) case Some(_) => logInfo("Not spawning another attempt to register with the master, since there is an" + " attempt scheduled already.") } } - override def receiveWithLogging: PartialFunction[Any, Unit] = { - case RegisteredWorker(masterUrl, masterWebUiUrl) => - logInfo("Successfully registered with master " + masterUrl) + override def receive: PartialFunction[Any, Unit] = { + case RegisteredWorker(masterRef, masterWebUiUrl) => + logInfo("Successfully registered with master " + masterRef.address.toSparkURL) registered = true - changeMaster(masterUrl, masterWebUiUrl) - context.system.scheduler.schedule(0 millis, HEARTBEAT_MILLIS millis, self, SendHeartbeat) + changeMaster(masterRef, masterWebUiUrl) + forwordMessageScheduler.scheduleAtFixedRate(new Runnable { + override def run(): Unit = Utils.tryLogNonFatalError { + self.send(SendHeartbeat) + } + }, 0, HEARTBEAT_MILLIS, TimeUnit.MILLISECONDS) if (CLEANUP_ENABLED) { logInfo(s"Worker cleanup enabled; old application directories will be deleted in: $workDir") - context.system.scheduler.schedule(CLEANUP_INTERVAL_MILLIS millis, - CLEANUP_INTERVAL_MILLIS millis, self, WorkDirCleanup) + forwordMessageScheduler.scheduleAtFixedRate(new Runnable { + override def run(): Unit = Utils.tryLogNonFatalError { + self.send(WorkDirCleanup) + } + }, CLEANUP_INTERVAL_MILLIS, CLEANUP_INTERVAL_MILLIS, TimeUnit.MILLISECONDS) } case SendHeartbeat => - if (connected) { master ! Heartbeat(workerId) } + if (connected) { sendToMaster(Heartbeat(workerId, self)) } case WorkDirCleanup => // Spin up a separate thread (in a future) to do the dir cleanup; don't tie up worker actor + // Copy ids so that it can be used in the cleanup thread. 
+ val appIds = executors.values.map(_.appId).toSet val cleanupFuture = concurrent.future { val appDirs = workDir.listFiles() if (appDirs == null) { @@ -302,27 +372,27 @@ private[worker] class Worker( // the directory is used by an application - check that the application is not running // when cleaning up val appIdFromDir = dir.getName - val isAppStillRunning = executors.values.map(_.appId).contains(appIdFromDir) + val isAppStillRunning = appIds.contains(appIdFromDir) dir.isDirectory && !isAppStillRunning && - !Utils.doesDirectoryContainAnyNewFiles(dir, APP_DATA_RETENTION_SECS) + !Utils.doesDirectoryContainAnyNewFiles(dir, APP_DATA_RETENTION_SECONDS) }.foreach { dir => logInfo(s"Removing directory: ${dir.getPath}") Utils.deleteRecursively(dir) } - } + }(cleanupThreadExecutor) - cleanupFuture onFailure { + cleanupFuture.onFailure { case e: Throwable => logError("App dir cleanup failed: " + e.getMessage, e) - } + }(cleanupThreadExecutor) - case MasterChanged(masterUrl, masterWebUiUrl) => - logInfo("Master has changed, new master is at " + masterUrl) - changeMaster(masterUrl, masterWebUiUrl) + case MasterChanged(masterRef, masterWebUiUrl) => + logInfo("Master has changed, new master is at " + masterRef.address.toSparkURL) + changeMaster(masterRef, masterWebUiUrl) val execs = executors.values. map(e => new ExecutorDescription(e.appId, e.execId, e.cores, e.state)) - sender ! WorkerSchedulerStateResponse(workerId, execs.toList, drivers.keys.toSeq) + masterRef.send(WorkerSchedulerStateResponse(workerId, execs.toList, drivers.keys.toSeq)) case RegisterWorkerFailed(message) => if (!registered) { @@ -369,14 +439,14 @@ private[worker] class Worker( publicAddress, sparkHome, executorDir, - akkaUrl, + workerUri, conf, appLocalDirs, ExecutorState.LOADING) executors(appId + "/" + execId) = manager manager.start() coresUsed += cores_ memoryUsed += memory_ - master ! ExecutorStateChanged(appId, execId, manager.state, None, None) + sendToMaster(ExecutorStateChanged(appId, execId, manager.state, None, None)) } catch { case e: Exception => { logError(s"Failed to launch executor $appId/$execId for ${appDesc.name}.", e) @@ -384,14 +454,14 @@ private[worker] class Worker( executors(appId + "/" + execId).kill() executors -= appId + "/" + execId } - master ! ExecutorStateChanged(appId, execId, ExecutorState.FAILED, - Some(e.toString), None) + sendToMaster(ExecutorStateChanged(appId, execId, ExecutorState.FAILED, + Some(e.toString), None)) } } } - case ExecutorStateChanged(appId, execId, state, message, exitStatus) => - master ! ExecutorStateChanged(appId, execId, state, message, exitStatus) + case executorStateChanged @ ExecutorStateChanged(appId, execId, state, message, exitStatus) => + sendToMaster(executorStateChanged) val fullId = appId + "/" + execId if (ExecutorState.isFinished(state)) { executors.get(fullId) match { @@ -434,7 +504,7 @@ private[worker] class Worker( sparkHome, driverDesc.copy(command = Worker.maybeUpdateSSLSettings(driverDesc.command, conf)), self, - akkaUrl, + workerUri, securityMgr) drivers(driverId) = driver driver.start() @@ -453,7 +523,7 @@ private[worker] class Worker( } } - case DriverStateChanged(driverId, state, exception) => { + case driverStageChanged @ DriverStateChanged(driverId, state, exception) => { state match { case DriverState.ERROR => logWarning(s"Driver $driverId failed with unrecoverable exception: ${exception.get}") @@ -466,23 +536,13 @@ private[worker] class Worker( case _ => logDebug(s"Driver $driverId changed state to $state") } - master ! 
DriverStateChanged(driverId, state, exception) + sendToMaster(driverStageChanged) val driver = drivers.remove(driverId).get finishedDrivers(driverId) = driver memoryUsed -= driver.driverDesc.mem coresUsed -= driver.driverDesc.cores } - case x: DisassociatedEvent if x.remoteAddress == masterAddress => - logInfo(s"$x Disassociated !") - masterDisconnected() - - case RequestWorkerState => - sender ! WorkerStateResponse(host, port, workerId, executors.values.toList, - finishedExecutors.values.toList, drivers.values.toList, - finishedDrivers.values.toList, activeMasterUrl, cores, memory, - coresUsed, memoryUsed, activeMasterWebUiUrl) - case ReregisterWithMaster => reregisterWithMaster() @@ -491,6 +551,21 @@ private[worker] class Worker( maybeCleanupApplication(id) } + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case RequestWorkerState => + context.reply(WorkerStateResponse(host, port, workerId, executors.values.toList, + finishedExecutors.values.toList, drivers.values.toList, + finishedDrivers.values.toList, activeMasterUrl, cores, memory, + coresUsed, memoryUsed, activeMasterWebUiUrl)) + } + + override def onDisconnected(remoteAddress: RpcAddress): Unit = { + if (master.exists(_.address == remoteAddress)) { + logInfo(s"$remoteAddress Disassociated !") + masterDisconnected() + } + } + private def masterDisconnected() { logError("Connection to master failed! Waiting for master to reconnect...") connected = false @@ -510,13 +585,29 @@ private[worker] class Worker( } } + /** + * Send a message to the current master. If we have not yet registered successfully with any + * master, the message will be dropped. + */ + private def sendToMaster(message: Any): Unit = { + master match { + case Some(masterRef) => masterRef.send(message) + case None => + logWarning( + s"Dropping $message because the connection to master has not yet been established") + } + } + private def generateWorkerId(): String = { "worker-%s-%s-%d".format(createDateFormat.format(new Date), host, port) } - override def postStop() { + override def onStop() { + cleanupThreadExecutor.shutdownNow() metricsSystem.report() - registrationRetryTimer.foreach(_.cancel()) + cancelLastRegistrationRetry() + forwordMessageScheduler.shutdownNow() + registerMasterThreadPool.shutdownNow() executors.values.foreach(_.kill()) drivers.values.foreach(_.kill()) shuffleService.stop() @@ -530,12 +621,12 @@ private[deploy] object Worker extends Logging { SignalLogger.register(log) val conf = new SparkConf val args = new WorkerArguments(argStrings, conf) - val (actorSystem, _) = startSystemAndActor(args.host, args.port, args.webUiPort, args.cores, + val rpcEnv = startRpcEnvAndEndpoint(args.host, args.port, args.webUiPort, args.cores, args.memory, args.masters, args.workDir) - actorSystem.awaitTermination() + rpcEnv.awaitTermination() } - def startSystemAndActor( + def startRpcEnvAndEndpoint( host: String, port: Int, webUiPort: Int, @@ -544,18 +635,17 @@ private[deploy] object Worker extends Logging { masterUrls: Array[String], workDir: String, workerNumber: Option[Int] = None, - conf: SparkConf = new SparkConf): (ActorSystem, Int) = { + conf: SparkConf = new SparkConf): RpcEnv = { // The LocalSparkCluster runs multiple local sparkWorkerX actor systems val systemName = "sparkWorker" + workerNumber.map(_.toString).getOrElse("") val actorName = "Worker" val securityMgr = new SecurityManager(conf) - val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port, - conf = conf, securityManager = 
securityMgr) - val masterAkkaUrls = masterUrls.map(Master.toAkkaUrl(_, AkkaUtils.protocol(actorSystem))) - actorSystem.actorOf(Props(classOf[Worker], host, boundPort, webUiPort, cores, memory, - masterAkkaUrls, systemName, actorName, workDir, conf, securityMgr), name = actorName) - (actorSystem, boundPort) + val rpcEnv = RpcEnv.create(systemName, host, port, conf, securityMgr) + val masterAddresses = masterUrls.map(RpcAddress.fromSparkURL(_)) + rpcEnv.setupEndpoint(actorName, new Worker(rpcEnv, webUiPort, cores, memory, masterAddresses, + systemName, actorName, workDir, conf, securityMgr)) + rpcEnv } def isUseLocalNodeSSLConfig(cmd: Command): Boolean = { diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala index 83fb991891a41..fae5640b9a213 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala @@ -18,7 +18,6 @@ package org.apache.spark.deploy.worker import org.apache.spark.Logging -import org.apache.spark.deploy.DeployMessages.SendHeartbeat import org.apache.spark.rpc._ /** diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala index 9f9f27d71e1ae..fd905feb97e92 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala @@ -17,10 +17,8 @@ package org.apache.spark.deploy.worker.ui -import scala.concurrent.Await import scala.xml.Node -import akka.pattern.ask import javax.servlet.http.HttpServletRequest import org.json4s.JValue @@ -32,18 +30,15 @@ import org.apache.spark.ui.{WebUIPage, UIUtils} import org.apache.spark.util.Utils private[ui] class WorkerPage(parent: WorkerWebUI) extends WebUIPage("") { - private val workerActor = parent.worker.self - private val timeout = parent.timeout + private val workerEndpoint = parent.worker.self override def renderJson(request: HttpServletRequest): JValue = { - val stateFuture = (workerActor ? RequestWorkerState)(timeout).mapTo[WorkerStateResponse] - val workerState = Await.result(stateFuture, timeout) + val workerState = workerEndpoint.askWithRetry[WorkerStateResponse](RequestWorkerState) JsonProtocol.writeWorkerState(workerState) } def render(request: HttpServletRequest): Seq[Node] = { - val stateFuture = (workerActor ? 
RequestWorkerState)(timeout).mapTo[WorkerStateResponse] - val workerState = Await.result(stateFuture, timeout) + val workerState = workerEndpoint.askWithRetry[WorkerStateResponse](RequestWorkerState) val executorHeaders = Seq("ExecutorID", "Cores", "State", "Memory", "Job Details", "Logs") val runningExecutors = workerState.executors diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala index 12b6b28d4d7ec..3b6938ec639c3 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -158,6 +158,8 @@ private[spark] case class RpcAddress(host: String, port: Int) { val hostPort: String = host + ":" + port override val toString: String = hostPort + + def toSparkURL: String = "spark://" + hostPort } diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala index 0161962cde073..31ebe5ac5bca3 100644 --- a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala @@ -180,10 +180,10 @@ private[spark] class AkkaRpcEnv private[akka] ( }) } catch { case NonFatal(e) => - if (needReply) { - // If the sender asks a reply, we should send the error back to the sender - _sender ! AkkaFailure(e) - } else { + _sender ! AkkaFailure(e) + if (!needReply) { + // If the sender does not require a reply, it may not handle the exception. So we rethrow + // "e" to make sure it will be processed. throw e } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index ccf1dc5af6120..687ae9620460f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -85,7 +85,7 @@ private[spark] class SparkDeploySchedulerBackend( val coresPerExecutor = conf.getOption("spark.executor.cores").map(_.toInt) val appDesc = new ApplicationDescription(sc.appName, maxCores, sc.executorMemory, command, appUIAddress, sc.eventLogDir, sc.eventLogCodec, coresPerExecutor) - client = new AppClient(sc.env.actorSystem, masters, appDesc, this, conf) + client = new AppClient(sc.env.rpcEnv, masters, appDesc, this, conf) client.start() waitForRegistration() } diff --git a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala index 014e87bb40254..9cb6dd43bac47 100644 --- a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala @@ -19,63 +19,21 @@ package org.apache.spark.deploy.master import java.util.Date -import scala.concurrent.Await import scala.concurrent.duration._ import scala.io.Source import scala.language.postfixOps -import akka.actor.Address import org.json4s._ import org.json4s.jackson.JsonMethods._ import org.scalatest.Matchers import org.scalatest.concurrent.Eventually import other.supplier.{CustomPersistenceEngine, CustomRecoveryModeFactory} -import org.apache.spark.{SparkConf, SparkException, SparkFunSuite} +import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.deploy._ class MasterSuite extends SparkFunSuite with Matchers with Eventually { - test("toAkkaUrl") { - val conf = new 
SparkConf(loadDefaults = false) - val akkaUrl = Master.toAkkaUrl("spark://1.2.3.4:1234", "akka.tcp") - assert("akka.tcp://sparkMaster@1.2.3.4:1234/user/Master" === akkaUrl) - } - - test("toAkkaUrl with SSL") { - val conf = new SparkConf(loadDefaults = false) - val akkaUrl = Master.toAkkaUrl("spark://1.2.3.4:1234", "akka.ssl.tcp") - assert("akka.ssl.tcp://sparkMaster@1.2.3.4:1234/user/Master" === akkaUrl) - } - - test("toAkkaUrl: a typo url") { - val conf = new SparkConf(loadDefaults = false) - val e = intercept[SparkException] { - Master.toAkkaUrl("spark://1.2. 3.4:1234", "akka.tcp") - } - assert("Invalid master URL: spark://1.2. 3.4:1234" === e.getMessage) - } - - test("toAkkaAddress") { - val conf = new SparkConf(loadDefaults = false) - val address = Master.toAkkaAddress("spark://1.2.3.4:1234", "akka.tcp") - assert(Address("akka.tcp", "sparkMaster", "1.2.3.4", 1234) === address) - } - - test("toAkkaAddress with SSL") { - val conf = new SparkConf(loadDefaults = false) - val address = Master.toAkkaAddress("spark://1.2.3.4:1234", "akka.ssl.tcp") - assert(Address("akka.ssl.tcp", "sparkMaster", "1.2.3.4", 1234) === address) - } - - test("toAkkaAddress: a typo url") { - val conf = new SparkConf(loadDefaults = false) - val e = intercept[SparkException] { - Master.toAkkaAddress("spark://1.2. 3.4:1234", "akka.tcp") - } - assert("Invalid master URL: spark://1.2. 3.4:1234" === e.getMessage) - } - test("can use a custom recovery mode factory") { val conf = new SparkConf(loadDefaults = false) conf.set("spark.deploy.recoveryMode", "CUSTOM") @@ -129,16 +87,16 @@ class MasterSuite extends SparkFunSuite with Matchers with Eventually { port = 10000, cores = 0, memory = 0, - actor = null, + endpoint = null, webUiPort = 0, publicAddress = "" ) - val (actorSystem, port, uiPort, restPort) = - Master.startSystemAndActor("127.0.0.1", 7077, 8080, conf) + val (rpcEnv, uiPort, restPort) = + Master.startRpcEnvAndEndpoint("127.0.0.1", 7077, 8080, conf) try { - Await.result(actorSystem.actorSelection("/user/Master").resolveOne(10 seconds), 10 seconds) + rpcEnv.setupEndpointRef(Master.SYSTEM_NAME, rpcEnv.address, Master.ENDPOINT_NAME) CustomPersistenceEngine.lastInstance.isDefined shouldBe true val persistenceEngine = CustomPersistenceEngine.lastInstance.get @@ -154,8 +112,8 @@ class MasterSuite extends SparkFunSuite with Matchers with Eventually { workers.map(_.id) should contain(workerToPersist.id) } finally { - actorSystem.shutdown() - actorSystem.awaitTermination() + rpcEnv.shutdown() + rpcEnv.awaitTermination() } CustomRecoveryModeFactory.instantiationAttempts should be > instantiationAttempts diff --git a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala index 197f68e7ec5ed..96e456d889ac3 100644 --- a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala @@ -23,14 +23,14 @@ import javax.servlet.http.HttpServletResponse import scala.collection.mutable -import akka.actor.{Actor, ActorRef, ActorSystem, Props} import com.google.common.base.Charsets import org.scalatest.BeforeAndAfterEach import org.json4s.JsonAST._ import org.json4s.jackson.JsonMethods._ import org.apache.spark._ -import org.apache.spark.util.{AkkaUtils, Utils} +import org.apache.spark.rpc._ +import org.apache.spark.util.Utils import org.apache.spark.deploy.DeployMessages._ import org.apache.spark.deploy.{SparkSubmit, 
SparkSubmitArguments} import org.apache.spark.deploy.master.DriverState._ @@ -39,11 +39,11 @@ import org.apache.spark.deploy.master.DriverState._ * Tests for the REST application submission protocol used in standalone cluster mode. */ class StandaloneRestSubmitSuite extends SparkFunSuite with BeforeAndAfterEach { - private var actorSystem: Option[ActorSystem] = None + private var rpcEnv: Option[RpcEnv] = None private var server: Option[RestSubmissionServer] = None override def afterEach() { - actorSystem.foreach(_.shutdown()) + rpcEnv.foreach(_.shutdown()) server.foreach(_.stop()) } @@ -377,31 +377,32 @@ class StandaloneRestSubmitSuite extends SparkFunSuite with BeforeAndAfterEach { killMessage: String = "driver is killed", state: DriverState = FINISHED, exception: Option[Exception] = None): String = { - startServer(new DummyMaster(submitId, submitMessage, killMessage, state, exception)) + startServer(new DummyMaster(_, submitId, submitMessage, killMessage, state, exception)) } /** Start a smarter dummy server that keeps track of submitted driver states. */ private def startSmartServer(): String = { - startServer(new SmarterMaster) + startServer(new SmarterMaster(_)) } /** Start a dummy server that is faulty in many ways... */ private def startFaultyServer(): String = { - startServer(new DummyMaster, faulty = true) + startServer(new DummyMaster(_), faulty = true) } /** - * Start a [[StandaloneRestServer]] that communicates with the given actor. + * Start a [[StandaloneRestServer]] that communicates with the given endpoint. * If `faulty` is true, start an [[FaultyStandaloneRestServer]] instead. * Return the master URL that corresponds to the address of this server. */ - private def startServer(makeFakeMaster: => Actor, faulty: Boolean = false): String = { + private def startServer( + makeFakeMaster: RpcEnv => RpcEndpoint, faulty: Boolean = false): String = { val name = "test-standalone-rest-protocol" val conf = new SparkConf val localhost = Utils.localHostName() val securityManager = new SecurityManager(conf) - val (_actorSystem, _) = AkkaUtils.createActorSystem(name, localhost, 0, conf, securityManager) - val fakeMasterRef = _actorSystem.actorOf(Props(makeFakeMaster)) + val _rpcEnv = RpcEnv.create(name, localhost, 0, conf, securityManager) + val fakeMasterRef = _rpcEnv.setupEndpoint("fake-master", makeFakeMaster(_rpcEnv)) val _server = if (faulty) { new FaultyStandaloneRestServer(localhost, 0, conf, fakeMasterRef, "spark://fake:7077") @@ -410,7 +411,7 @@ class StandaloneRestSubmitSuite extends SparkFunSuite with BeforeAndAfterEach { } val port = _server.start() // set these to clean them up after every test - actorSystem = Some(_actorSystem) + rpcEnv = Some(_rpcEnv) server = Some(_server) s"spark://$localhost:$port" } @@ -505,20 +506,21 @@ class StandaloneRestSubmitSuite extends SparkFunSuite with BeforeAndAfterEach { * In all responses, the success parameter is always true. */ private class DummyMaster( + override val rpcEnv: RpcEnv, submitId: String = "fake-driver-id", submitMessage: String = "submitted", killMessage: String = "killed", state: DriverState = FINISHED, exception: Option[Exception] = None) - extends Actor { + extends RpcEndpoint { - override def receive: PartialFunction[Any, Unit] = { + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case RequestSubmitDriver(driverDesc) => - sender ! 
SubmitDriverResponse(success = true, Some(submitId), submitMessage) + context.reply(SubmitDriverResponse(self, success = true, Some(submitId), submitMessage)) case RequestKillDriver(driverId) => - sender ! KillDriverResponse(driverId, success = true, killMessage) + context.reply(KillDriverResponse(self, driverId, success = true, killMessage)) case RequestDriverStatus(driverId) => - sender ! DriverStatusResponse(found = true, Some(state), None, None, exception) + context.reply(DriverStatusResponse(found = true, Some(state), None, None, exception)) } } @@ -531,28 +533,28 @@ private class DummyMaster( * Submits are always successful while kills and status requests are successful only * if the driver was submitted in the past. */ -private class SmarterMaster extends Actor { +private class SmarterMaster(override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint { private var counter: Int = 0 private val submittedDrivers = new mutable.HashMap[String, DriverState] - override def receive: PartialFunction[Any, Unit] = { + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case RequestSubmitDriver(driverDesc) => val driverId = s"driver-$counter" submittedDrivers(driverId) = RUNNING counter += 1 - sender ! SubmitDriverResponse(success = true, Some(driverId), "submitted") + context.reply(SubmitDriverResponse(self, success = true, Some(driverId), "submitted")) case RequestKillDriver(driverId) => val success = submittedDrivers.contains(driverId) if (success) { submittedDrivers(driverId) = KILLED } - sender ! KillDriverResponse(driverId, success, "killed") + context.reply(KillDriverResponse(self, driverId, success, "killed")) case RequestDriverStatus(driverId) => val found = submittedDrivers.contains(driverId) val state = submittedDrivers.get(driverId) - sender ! DriverStatusResponse(found, state, None, None, None) + context.reply(DriverStatusResponse(found, state, None, None, None)) } } @@ -568,7 +570,7 @@ private class FaultyStandaloneRestServer( host: String, requestedPort: Int, masterConf: SparkConf, - masterActor: ActorRef, + masterEndpoint: RpcEndpointRef, masterUrl: String) extends RestSubmissionServer(host, requestedPort, masterConf) { @@ -578,7 +580,7 @@ private class FaultyStandaloneRestServer( /** A faulty servlet that produces malformed responses. */ class MalformedSubmitServlet - extends StandaloneSubmitRequestServlet(masterActor, masterUrl, masterConf) { + extends StandaloneSubmitRequestServlet(masterEndpoint, masterUrl, masterConf) { protected override def sendResponse( responseMessage: SubmitRestProtocolResponse, responseServlet: HttpServletResponse): Unit = { @@ -588,7 +590,7 @@ private class FaultyStandaloneRestServer( } /** A faulty servlet that produces invalid responses. */ - class InvalidKillServlet extends StandaloneKillRequestServlet(masterActor, masterConf) { + class InvalidKillServlet extends StandaloneKillRequestServlet(masterEndpoint, masterConf) { protected override def handleKill(submissionId: String): KillSubmissionResponse = { val k = super.handleKill(submissionId) k.submissionId = null @@ -597,7 +599,7 @@ private class FaultyStandaloneRestServer( } /** A faulty status servlet that explodes. 
*/ - class ExplodingStatusServlet extends StandaloneStatusRequestServlet(masterActor, masterConf) { + class ExplodingStatusServlet extends StandaloneStatusRequestServlet(masterEndpoint, masterConf) { private def explode: Int = 1 / 0 protected override def handleStatus(submissionId: String): SubmissionStatusResponse = { val s = super.handleStatus(submissionId) diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala index ac18f04a11475..cd24d79423316 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark.deploy.worker -import akka.actor.AddressFromURIString import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.SecurityManager import org.apache.spark.rpc.{RpcAddress, RpcEnv} @@ -26,13 +25,11 @@ class WorkerWatcherSuite extends SparkFunSuite { test("WorkerWatcher shuts down on valid disassociation") { val conf = new SparkConf() val rpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf)) - val targetWorkerUrl = "akka://test@1.2.3.4:1234/user/Worker" - val targetWorkerAddress = AddressFromURIString(targetWorkerUrl) + val targetWorkerUrl = rpcEnv.uriOf("test", RpcAddress("1.2.3.4", 1234), "Worker") val workerWatcher = new WorkerWatcher(rpcEnv, targetWorkerUrl) workerWatcher.setTesting(testing = true) rpcEnv.setupEndpoint("worker-watcher", workerWatcher) - workerWatcher.onDisconnected( - RpcAddress(targetWorkerAddress.host.get, targetWorkerAddress.port.get)) + workerWatcher.onDisconnected(RpcAddress("1.2.3.4", 1234)) assert(workerWatcher.isShutDown) rpcEnv.shutdown() } @@ -40,13 +37,13 @@ class WorkerWatcherSuite extends SparkFunSuite { test("WorkerWatcher stays alive on invalid disassociation") { val conf = new SparkConf() val rpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf)) - val targetWorkerUrl = "akka://test@1.2.3.4:1234/user/Worker" - val otherAkkaURL = "akka://test@4.3.2.1:1234/user/OtherActor" - val otherAkkaAddress = AddressFromURIString(otherAkkaURL) + val targetWorkerUrl = rpcEnv.uriOf("test", RpcAddress("1.2.3.4", 1234), "Worker") + val otherAddress = "akka://test@4.3.2.1:1234/user/OtherActor" + val otherAkkaAddress = RpcAddress("4.3.2.1", 1234) val workerWatcher = new WorkerWatcher(rpcEnv, targetWorkerUrl) workerWatcher.setTesting(testing = true) rpcEnv.setupEndpoint("worker-watcher", workerWatcher) - workerWatcher.onDisconnected(RpcAddress(otherAkkaAddress.host.get, otherAkkaAddress.port.get)) + workerWatcher.onDisconnected(otherAkkaAddress) assert(!workerWatcher.isShutDown) rpcEnv.shutdown() } diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcAddressSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcAddressSuite.scala new file mode 100644 index 0000000000000..b3223ec61bf79 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/rpc/RpcAddressSuite.scala @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rpc + +import org.apache.spark.{SparkException, SparkFunSuite} + +class RpcAddressSuite extends SparkFunSuite { + + test("hostPort") { + val address = RpcAddress("1.2.3.4", 1234) + assert(address.host == "1.2.3.4") + assert(address.port == 1234) + assert(address.hostPort == "1.2.3.4:1234") + } + + test("fromSparkURL") { + val address = RpcAddress.fromSparkURL("spark://1.2.3.4:1234") + assert(address.host == "1.2.3.4") + assert(address.port == 1234) + } + + test("fromSparkURL: a typo url") { + val e = intercept[SparkException] { + RpcAddress.fromSparkURL("spark://1.2. 3.4:1234") + } + assert("Invalid master URL: spark://1.2. 3.4:1234" === e.getMessage) + } + + test("fromSparkURL: invalid scheme") { + val e = intercept[SparkException] { + RpcAddress.fromSparkURL("invalid://1.2.3.4:1234") + } + assert("Invalid master URL: invalid://1.2.3.4:1234" === e.getMessage) + } + + test("toSparkURL") { + val address = RpcAddress("1.2.3.4", 1234) + assert(address.toSparkURL == "spark://1.2.3.4:1234") + } +} diff --git a/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala index a33a83db7bc9e..4aa75c9230b2c 100644 --- a/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.rpc.akka import org.apache.spark.rpc._ -import org.apache.spark.{SecurityManager, SparkConf} +import org.apache.spark.{SSLSampleConfigs, SecurityManager, SparkConf} class AkkaRpcEnvSuite extends RpcEnvSuite { @@ -47,4 +47,22 @@ class AkkaRpcEnvSuite extends RpcEnvSuite { } } + test("uriOf") { + val uri = env.uriOf("local", RpcAddress("1.2.3.4", 12345), "test_endpoint") + assert("akka.tcp://local@1.2.3.4:12345/user/test_endpoint" === uri) + } + + test("uriOf: ssl") { + val conf = SSLSampleConfigs.sparkSSLConfig() + val securityManager = new SecurityManager(conf) + val rpcEnv = new AkkaRpcEnvFactory().create( + RpcEnvConfig(conf, "test", "localhost", 12346, securityManager)) + try { + val uri = rpcEnv.uriOf("local", RpcAddress("1.2.3.4", 12345), "test_endpoint") + assert("akka.ssl.tcp://local@1.2.3.4:12345/user/test_endpoint" === uri) + } finally { + rpcEnv.shutdown() + } + } + } From f457569886e9de9256ad269cb4a3d73a8918766d Mon Sep 17 00:00:00 2001 From: Feynman Liang Date: Tue, 30 Jun 2015 20:19:43 -0700 Subject: [PATCH 111/122] [SPARK-8471] [ML] Rename DiscreteCosineTransformer to DCT Rename DiscreteCosineTransformer and related classes to DCT. 
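For orientation, a minimal usage sketch of the renamed transformer, mirroring the pattern exercised in the updated DCTSuite/JavaDCTSuite below; the `sqlContext` handle and the sample vector are assumed for illustration and are not part of this patch:

```
import org.apache.spark.ml.feature.DCT
import org.apache.spark.mllib.linalg.Vectors

// Assumes an existing SQLContext named `sqlContext`; the input data is made up.
val data = Seq(Vectors.dense(0.0, 1.0, -2.0, 3.0))
val df = sqlContext.createDataFrame(data.map(Tuple1.apply)).toDF("vec")

val dct = new DCT()            // formerly DiscreteCosineTransformer
  .setInputCol("vec")
  .setOutputCol("resultVec")
  .setInverse(false)

dct.transform(df).select("resultVec").show()
```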
Author: Feynman Liang Closes #7138 from feynmanliang/dct-features and squashes the following commits: e547b3e [Feynman Liang] Fix renaming bug 9d5c9e4 [Feynman Liang] Lowercase JavaDCTSuite variable f9a8958 [Feynman Liang] Remove old files f8fe794 [Feynman Liang] Merge branch 'master' into dct-features 894d0b2 [Feynman Liang] Rename DiscreteCosineTransformer to DCT 433dbc7 [Feynman Liang] Test refactoring 91e9636 [Feynman Liang] Style guide and test helper refactor b5ac19c [Feynman Liang] Use Vector types, add Java test 530983a [Feynman Liang] Tests for other numeric datatypes 195d7aa [Feynman Liang] Implement support for arbitrary numeric types 95d4939 [Feynman Liang] Working DCT for 1D Doubles --- .../{DiscreteCosineTransformer.scala => DCT.scala} | 4 ++-- ...creteCosineTransformerSuite.java => JavaDCTSuite.java} | 8 ++++---- ...iscreteCosineTransformerSuite.scala => DCTSuite.scala} | 4 ++-- 3 files changed, 8 insertions(+), 8 deletions(-) rename mllib/src/main/scala/org/apache/spark/ml/feature/{DiscreteCosineTransformer.scala => DCT.scala} (95%) rename mllib/src/test/java/org/apache/spark/ml/feature/{JavaDiscreteCosineTransformerSuite.java => JavaDCTSuite.java} (90%) rename mllib/src/test/scala/org/apache/spark/ml/feature/{DiscreteCosineTransformerSuite.scala => DCTSuite.scala} (94%) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/DiscreteCosineTransformer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala similarity index 95% rename from mllib/src/main/scala/org/apache/spark/ml/feature/DiscreteCosineTransformer.scala rename to mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala index a2f4d59f81c44..228347635c92b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/DiscreteCosineTransformer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala @@ -36,8 +36,8 @@ import org.apache.spark.sql.types.DataType * More information on [[https://en.wikipedia.org/wiki/Discrete_cosine_transform#DCT-II Wikipedia]]. 
*/ @Experimental -class DiscreteCosineTransformer(override val uid: String) - extends UnaryTransformer[Vector, Vector, DiscreteCosineTransformer] { +class DCT(override val uid: String) + extends UnaryTransformer[Vector, Vector, DCT] { def this() = this(Identifiable.randomUID("dct")) diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaDiscreteCosineTransformerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java similarity index 90% rename from mllib/src/test/java/org/apache/spark/ml/feature/JavaDiscreteCosineTransformerSuite.java rename to mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java index 28bc5f65e0532..845eed61c45c6 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaDiscreteCosineTransformerSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java @@ -37,13 +37,13 @@ import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; -public class JavaDiscreteCosineTransformerSuite { +public class JavaDCTSuite { private transient JavaSparkContext jsc; private transient SQLContext jsql; @Before public void setUp() { - jsc = new JavaSparkContext("local", "JavaDiscreteCosineTransformerSuite"); + jsc = new JavaSparkContext("local", "JavaDCTSuite"); jsql = new SQLContext(jsc); } @@ -66,11 +66,11 @@ public void javaCompatibilityTest() { double[] expectedResult = input.clone(); (new DoubleDCT_1D(input.length)).forward(expectedResult, true); - DiscreteCosineTransformer DCT = new DiscreteCosineTransformer() + DCT dct = new DCT() .setInputCol("vec") .setOutputCol("resultVec"); - Row[] result = DCT.transform(dataset).select("resultVec").collect(); + Row[] result = dct.transform(dataset).select("resultVec").collect(); Vector resultVec = result[0].getAs("resultVec"); Assert.assertArrayEquals(expectedResult, resultVec.toArray(), 1e-6); diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/DiscreteCosineTransformerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala similarity index 94% rename from mllib/src/test/scala/org/apache/spark/ml/feature/DiscreteCosineTransformerSuite.scala rename to mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala index ed0fc11f78f69..37ed2367c33f7 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/DiscreteCosineTransformerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/DCTSuite.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.{DataFrame, Row} @BeanInfo case class DCTTestData(vec: Vector, wantedVec: Vector) -class DiscreteCosineTransformerSuite extends SparkFunSuite with MLlibTestSparkContext { +class DCTSuite extends SparkFunSuite with MLlibTestSparkContext { test("forward transform of discrete cosine matches jTransforms result") { val data = Vectors.dense((0 until 128).map(_ => 2D * math.random - 1D).toArray) @@ -58,7 +58,7 @@ class DiscreteCosineTransformerSuite extends SparkFunSuite with MLlibTestSparkCo DCTTestData(data, expectedResult) )) - val transformer = new DiscreteCosineTransformer() + val transformer = new DCT() .setInputCol("vec") .setOutputCol("resultVec") .setInverse(inverse) From b6e76edf3005c078b407f63b0a05d3a28c18c742 Mon Sep 17 00:00:00 2001 From: x1- Date: Tue, 30 Jun 2015 20:35:46 -0700 Subject: [PATCH 112/122] [SPARK-8535] [PYSPARK] PySpark : Can't create DataFrame from Pandas dataframe with no explicit column name Because implicit name of `pandas.columns` are Int, but `StructField` json expect `String`. 
So I think `pandas.columns` are should be convert to `String`. ### issue * [SPARK-8535 PySpark : Can't create DataFrame from Pandas dataframe with no explicit column name](https://issues.apache.org/jira/browse/SPARK-8535) Author: x1- Closes #7124 from x1-/SPARK-8535 and squashes the following commits: d68fd38 [x1-] modify unit-test using pandas. ea1897d [x1-] For implicit name of pandas.columns are Int, so should be convert to String. --- python/pyspark/sql/context.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index 4bf232111c496..309c11faf9319 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -344,13 +344,15 @@ def createDataFrame(self, data, schema=None, samplingRatio=None): >>> sqlContext.createDataFrame(df.toPandas()).collect() # doctest: +SKIP [Row(name=u'Alice', age=1)] + >>> sqlContext.createDataFrame(pandas.DataFrame([[1, 2]]).collect()) # doctest: +SKIP + [Row(0=1, 1=2)] """ if isinstance(data, DataFrame): raise TypeError("data is already a DataFrame") if has_pandas and isinstance(data, pandas.DataFrame): if schema is None: - schema = list(data.columns) + schema = [str(x) for x in data.columns] data = [r.tolist() for r in data.to_records(index=False)] if not isinstance(data, RDD): From 64c14618d3f4ede042bd3f6a542bc17a730afb0e Mon Sep 17 00:00:00 2001 From: zsxwing Date: Tue, 30 Jun 2015 21:57:07 -0700 Subject: [PATCH 113/122] [SPARK-6602][Core]Remove unnecessary synchronized A follow-up pr to address https://github.com/apache/spark/pull/5392#discussion_r33627528 Author: zsxwing Closes #7141 from zsxwing/pr5392-follow-up and squashes the following commits: fcf7b50 [zsxwing] Remove unnecessary synchronized --- .../main/scala/org/apache/spark/deploy/master/Master.scala | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 3e7c16722805e..48070768f6edb 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -518,12 +518,9 @@ private[master] class Master( } private def completeRecovery() { - // TODO Why synchronized // Ensure "only-once" recovery semantics using a short synchronization period. - synchronized { - if (state != RecoveryState.RECOVERING) { return } - state = RecoveryState.COMPLETING_RECOVERY - } + if (state != RecoveryState.RECOVERING) { return } + state = RecoveryState.COMPLETING_RECOVERY // Kill off any workers and apps that didn't respond to us. workers.filter(_.state == WorkerState.UNKNOWN).foreach(removeWorker) From 365c14055e90db5ea4b25afec03022be81c8a704 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 30 Jun 2015 23:04:54 -0700 Subject: [PATCH 114/122] [SPARK-8748][SQL] Move castability test out from Cast case class into Cast object. This patch moved resolve function in Cast case class into the companion object, and renamed it canCast. We can then use this in the analyzer without a Cast expr. Author: Reynold Xin Closes #7145 from rxin/cast and squashes the following commits: cd086a9 [Reynold Xin] Whitespace changes. 4d2d989 [Reynold Xin] [SPARK-8748][SQL] Move castability test out from Cast case class into Cast object. 
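A small sketch of the new entry point, with the signature taken from the diff below; the sample types are illustrative only. The point of the move is that callers such as the analyzer can ask the companion object directly, without constructing a Cast expression:

```
import org.apache.spark.sql.catalyst.expressions.Cast
import org.apache.spark.sql.types._

// Per the rules in canCast: string -> numeric is castable, map -> int is not.
val ok = Cast.canCast(StringType, IntegerType)                          // true
val notOk = Cast.canCast(MapType(IntegerType, StringType), IntegerType) // false
```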
--- .../spark/sql/catalyst/expressions/Cast.scala | 144 ++++++++++-------- 1 file changed, 78 insertions(+), 66 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index d69d490ad666a..2d99d1a3fe8dc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -27,23 +27,65 @@ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String -/** Cast the child expression to the target data type. */ -case class Cast(child: Expression, dataType: DataType) extends UnaryExpression with Logging { - override def checkInputDataTypes(): TypeCheckResult = { - if (resolve(child.dataType, dataType)) { - TypeCheckResult.TypeCheckSuccess - } else { - TypeCheckResult.TypeCheckFailure( - s"cannot cast ${child.dataType} to $dataType") - } - } +object Cast { - override def foldable: Boolean = child.foldable + /** + * Returns true iff we can cast `from` type to `to` type. + */ + def canCast(from: DataType, to: DataType): Boolean = (from, to) match { + case (fromType, toType) if fromType == toType => true + + case (NullType, _) => true + + case (_, StringType) => true - override def nullable: Boolean = forceNullable(child.dataType, dataType) || child.nullable + case (StringType, BinaryType) => true - private[this] def forceNullable(from: DataType, to: DataType) = (from, to) match { + case (StringType, BooleanType) => true + case (DateType, BooleanType) => true + case (TimestampType, BooleanType) => true + case (_: NumericType, BooleanType) => true + + case (StringType, TimestampType) => true + case (BooleanType, TimestampType) => true + case (DateType, TimestampType) => true + case (_: NumericType, TimestampType) => true + + case (_, DateType) => true + + case (StringType, _: NumericType) => true + case (BooleanType, _: NumericType) => true + case (DateType, _: NumericType) => true + case (TimestampType, _: NumericType) => true + case (_: NumericType, _: NumericType) => true + + case (ArrayType(fromType, fn), ArrayType(toType, tn)) => + canCast(fromType, toType) && + resolvableNullability(fn || forceNullable(fromType, toType), tn) + + case (MapType(fromKey, fromValue, fn), MapType(toKey, toValue, tn)) => + canCast(fromKey, toKey) && + (!forceNullable(fromKey, toKey)) && + canCast(fromValue, toValue) && + resolvableNullability(fn || forceNullable(fromValue, toValue), tn) + + case (StructType(fromFields), StructType(toFields)) => + fromFields.length == toFields.length && + fromFields.zip(toFields).forall { + case (fromField, toField) => + canCast(fromField.dataType, toField.dataType) && + resolvableNullability( + fromField.nullable || forceNullable(fromField.dataType, toField.dataType), + toField.nullable) + } + + case _ => false + } + + private def resolvableNullability(from: Boolean, to: Boolean) = !from || to + + private def forceNullable(from: DataType, to: DataType) = (from, to) match { case (StringType, _: NumericType) => true case (StringType, TimestampType) => true case (DoubleType, TimestampType) => true @@ -58,61 +100,24 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w case (_, DecimalType.Fixed(_, _)) => true // TODO: not all upcasts here can really give null case _ => false } +} - private[this] def resolvableNullability(from: Boolean, 
to: Boolean) = !from || to - - private[this] def resolve(from: DataType, to: DataType): Boolean = { - (from, to) match { - case (from, to) if from == to => true - - case (NullType, _) => true - - case (_, StringType) => true - - case (StringType, BinaryType) => true - - case (StringType, BooleanType) => true - case (DateType, BooleanType) => true - case (TimestampType, BooleanType) => true - case (_: NumericType, BooleanType) => true - - case (StringType, TimestampType) => true - case (BooleanType, TimestampType) => true - case (DateType, TimestampType) => true - case (_: NumericType, TimestampType) => true - - case (_, DateType) => true - - case (StringType, _: NumericType) => true - case (BooleanType, _: NumericType) => true - case (DateType, _: NumericType) => true - case (TimestampType, _: NumericType) => true - case (_: NumericType, _: NumericType) => true - - case (ArrayType(from, fn), ArrayType(to, tn)) => - resolve(from, to) && - resolvableNullability(fn || forceNullable(from, to), tn) - - case (MapType(fromKey, fromValue, fn), MapType(toKey, toValue, tn)) => - resolve(fromKey, toKey) && - (!forceNullable(fromKey, toKey)) && - resolve(fromValue, toValue) && - resolvableNullability(fn || forceNullable(fromValue, toValue), tn) - - case (StructType(fromFields), StructType(toFields)) => - fromFields.size == toFields.size && - fromFields.zip(toFields).forall { - case (fromField, toField) => - resolve(fromField.dataType, toField.dataType) && - resolvableNullability( - fromField.nullable || forceNullable(fromField.dataType, toField.dataType), - toField.nullable) - } +/** Cast the child expression to the target data type. */ +case class Cast(child: Expression, dataType: DataType) extends UnaryExpression with Logging { - case _ => false + override def checkInputDataTypes(): TypeCheckResult = { + if (Cast.canCast(child.dataType, dataType)) { + TypeCheckResult.TypeCheckSuccess + } else { + TypeCheckResult.TypeCheckFailure( + s"cannot cast ${child.dataType} to $dataType") } } + override def foldable: Boolean = child.foldable + + override def nullable: Boolean = Cast.forceNullable(child.dataType, dataType) || child.nullable + override def toString: String = s"CAST($child, $dataType)" // [[func]] assumes the input is no longer null because eval already does the null check. @@ -172,7 +177,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w catch { case _: java.lang.IllegalArgumentException => null } }) case BooleanType => - buildCast[Boolean](_, b => (if (b) 1L else 0)) + buildCast[Boolean](_, b => if (b) 1L else 0) case LongType => buildCast[Long](_, l => longToTimestamp(l)) case IntegerType => @@ -388,7 +393,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w case (fromField, toField) => cast(fromField.dataType, toField.dataType) } // TODO: Could be faster? - val newRow = new GenericMutableRow(from.fields.size) + val newRow = new GenericMutableRow(from.fields.length) buildCast[InternalRow](_, row => { var i = 0 while (i < row.length) { @@ -427,20 +432,23 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - // TODO(cg): Add support for more data types. + // TODO: Add support for more data types. 
(child.dataType, dataType) match { case (BinaryType, StringType) => defineCodeGen (ctx, ev, c => s"${ctx.stringType}.fromBytes($c)") + case (DateType, StringType) => defineCodeGen(ctx, ev, c => s"""${ctx.stringType}.fromString( org.apache.spark.sql.catalyst.util.DateTimeUtils.dateToString($c))""") + case (TimestampType, StringType) => defineCodeGen(ctx, ev, c => s"""${ctx.stringType}.fromString( org.apache.spark.sql.catalyst.util.DateTimeUtils.timestampToString($c))""") + case (_, StringType) => defineCodeGen(ctx, ev, c => s"${ctx.stringType}.fromString(String.valueOf($c))") @@ -450,12 +458,16 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w case (BooleanType, dt: NumericType) => defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dt)})($c ? 1 : 0)") + case (dt: DecimalType, BooleanType) => defineCodeGen(ctx, ev, c => s"!$c.isZero()") + case (dt: NumericType, BooleanType) => defineCodeGen(ctx, ev, c => s"$c != 0") + case (_: DecimalType, dt: NumericType) => defineCodeGen(ctx, ev, c => s"($c).to${ctx.primitiveTypeName(dt)}()") + case (_: NumericType, dt: NumericType) => defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dt)})($c)") From fc3a6fe67f5aeda2443958c31f097daeba8549e5 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 1 Jul 2015 00:08:16 -0700 Subject: [PATCH 115/122] [SPARK-8749][SQL] Remove HiveTypeCoercion trait. Moved all the rules into the companion object. Author: Reynold Xin Closes #7147 from rxin/SPARK-8749 and squashes the following commits: c1c6dc0 [Reynold Xin] [SPARK-8749][SQL] Remove HiveTypeCoercion trait. --- .../sql/catalyst/analysis/Analyzer.scala | 4 +- .../catalyst/analysis/HiveTypeCoercion.scala | 59 ++++++++----------- .../analysis/HiveTypeCoercionSuite.scala | 14 ++--- 3 files changed, 33 insertions(+), 44 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 117c87a785fdb..15e84e68b9881 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -43,7 +43,7 @@ class Analyzer( registry: FunctionRegistry, conf: CatalystConf, maxIterations: Int = 100) - extends RuleExecutor[LogicalPlan] with HiveTypeCoercion with CheckAnalysis { + extends RuleExecutor[LogicalPlan] with CheckAnalysis { def resolver: Resolver = { if (conf.caseSensitiveAnalysis) { @@ -76,7 +76,7 @@ class Analyzer( ExtractWindowExpressions :: GlobalAggregates :: UnresolvedHavingClauseAttributes :: - typeCoercionRules ++ + HiveTypeCoercion.typeCoercionRules ++ extendedResolutionRules : _*) ) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index e525ad623ff12..a9d396d1faeeb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -22,7 +22,32 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Union} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.types._ + +/** + * A collection of [[Rule Rules]] that can be used to coerce differing types that + * participate in operations into compatible ones. 
Most of these rules are based on Hive semantics, + * but they do not introduce any dependencies on the hive codebase. For this reason they remain in + * Catalyst until we have a more standard set of coercions. + */ object HiveTypeCoercion { + + val typeCoercionRules = + PropagateTypes :: + ConvertNaNs :: + InConversion :: + WidenTypes :: + PromoteStrings :: + DecimalPrecision :: + BooleanEquality :: + StringToIntegralCasts :: + FunctionArgumentConversion :: + CaseWhenCoercion :: + IfCoercion :: + Division :: + PropagateTypes :: + AddCastForAutoCastInputTypes :: + Nil + // See https://cwiki.apache.org/confluence/display/Hive/LanguageManual+Types. // The conversion for integral and floating point types have a linear widening hierarchy: private val numericPrecedence = @@ -79,7 +104,6 @@ object HiveTypeCoercion { }) } - /** * Find the tightest common type of a set of types by continuously applying * `findTightestCommonTypeOfTwo` on these types. @@ -90,34 +114,6 @@ object HiveTypeCoercion { case Some(d) => findTightestCommonTypeOfTwo(d, c) }) } -} - -/** - * A collection of [[Rule Rules]] that can be used to coerce differing types that - * participate in operations into compatible ones. Most of these rules are based on Hive semantics, - * but they do not introduce any dependencies on the hive codebase. For this reason they remain in - * Catalyst until we have a more standard set of coercions. - */ -trait HiveTypeCoercion { - - import HiveTypeCoercion._ - - val typeCoercionRules = - PropagateTypes :: - ConvertNaNs :: - InConversion :: - WidenTypes :: - PromoteStrings :: - DecimalPrecision :: - BooleanEquality :: - StringToIntegralCasts :: - FunctionArgumentConversion :: - CaseWhenCoercion :: - IfCoercion :: - Division :: - PropagateTypes :: - AddCastForAutoCastInputTypes :: - Nil /** * Applies any changes to [[AttributeReference]] data types that are made by other rules to @@ -202,8 +198,6 @@ trait HiveTypeCoercion { * - LongType to DoubleType */ object WidenTypes extends Rule[LogicalPlan] { - import HiveTypeCoercion._ - def apply(plan: LogicalPlan): LogicalPlan = plan transform { // TODO: unions with fixed-precision decimals case u @ Union(left, right) if u.childrenResolved && !u.resolved => @@ -655,8 +649,6 @@ trait HiveTypeCoercion { * Coerces the type of different branches of a CASE WHEN statement to a common type. */ object CaseWhenCoercion extends Rule[LogicalPlan] { - import HiveTypeCoercion._ - def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { case c: CaseWhenLike if c.childrenResolved && !c.valueTypesEqual => logDebug(s"Input values for null casting ${c.valueTypes.mkString(",")}") @@ -714,7 +706,6 @@ trait HiveTypeCoercion { * [[AutoCastInputTypes]]. */ object AddCastForAutoCastInputTypes extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { // Skip nodes who's children have not been resolved yet. 
case e if !e.childrenResolved => e diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala index f7b8e21bed490..eae3666595a38 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala @@ -113,8 +113,7 @@ class HiveTypeCoercionSuite extends PlanTest { } test("coalesce casts") { - val fac = new HiveTypeCoercion { }.FunctionArgumentConversion - ruleTest(fac, + ruleTest(HiveTypeCoercion.FunctionArgumentConversion, Coalesce(Literal(1.0) :: Literal(1) :: Literal.create(1.0, FloatType) @@ -123,7 +122,7 @@ class HiveTypeCoercionSuite extends PlanTest { :: Cast(Literal(1), DoubleType) :: Cast(Literal.create(1.0, FloatType), DoubleType) :: Nil)) - ruleTest(fac, + ruleTest(HiveTypeCoercion.FunctionArgumentConversion, Coalesce(Literal(1L) :: Literal(1) :: Literal(new java.math.BigDecimal("1000000000000000000000")) @@ -135,7 +134,7 @@ class HiveTypeCoercionSuite extends PlanTest { } test("type coercion for If") { - val rule = new HiveTypeCoercion { }.IfCoercion + val rule = HiveTypeCoercion.IfCoercion ruleTest(rule, If(Literal(true), Literal(1), Literal(1L)), If(Literal(true), Cast(Literal(1), LongType), Literal(1L)) @@ -148,19 +147,18 @@ class HiveTypeCoercionSuite extends PlanTest { } test("type coercion for CaseKeyWhen") { - val cwc = new HiveTypeCoercion {}.CaseWhenCoercion - ruleTest(cwc, + ruleTest(HiveTypeCoercion.CaseWhenCoercion, CaseKeyWhen(Literal(1.toShort), Seq(Literal(1), Literal("a"))), CaseKeyWhen(Cast(Literal(1.toShort), IntegerType), Seq(Literal(1), Literal("a"))) ) - ruleTest(cwc, + ruleTest(HiveTypeCoercion.CaseWhenCoercion, CaseKeyWhen(Literal(true), Seq(Literal(1), Literal("a"))), CaseKeyWhen(Literal(true), Seq(Literal(1), Literal("a"))) ) } test("type coercion simplification for equal to") { - val be = new HiveTypeCoercion {}.BooleanEquality + val be = HiveTypeCoercion.BooleanEquality ruleTest(be, EqualTo(Literal(true), Literal(1)), From 0eee0615894cda8ae1b2c8e61b8bda0ff648a219 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 1 Jul 2015 01:02:33 -0700 Subject: [PATCH 116/122] [SQL] [MINOR] remove internalRowRDD in DataFrame Developers have already familiar with `queryExecution.toRDD` as internal row RDD, and we should not add new concept. 
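In other words, the convention after this change is the one sketched below (`df` is an assumed, already-constructed DataFrame; only the two access paths matter here): physical-plan level code reaches for `queryExecution.toRdd`, while `df.rdd` remains the user-facing, converted view.

```
// Hedged sketch; `df` is assumed to be an existing DataFrame.
val internalRows = df.queryExecution.toRdd // internal-row RDD, what df.internalRowRdd used to expose
val rows = df.rdd                          // user-facing RDD[Row], converted from internal rows per partition
```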
Author: Wenchen Fan Closes #7116 from cloud-fan/internal-rdd and squashes the following commits: 24756ca [Wenchen Fan] remove internalRowRDD --- sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala | 4 +--- .../org/apache/spark/sql/execution/stat/FrequentItems.scala | 2 +- .../org/apache/spark/sql/execution/stat/StatFunctions.scala | 2 +- .../main/scala/org/apache/spark/sql/sources/commands.scala | 4 ++-- 4 files changed, 5 insertions(+), 7 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 8fe1f7e34cb5e..caad2da80b1eb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -1469,14 +1469,12 @@ class DataFrame private[sql]( lazy val rdd: RDD[Row] = { // use a local variable to make sure the map closure doesn't capture the whole DataFrame val schema = this.schema - internalRowRdd.mapPartitions { rows => + queryExecution.toRdd.mapPartitions { rows => val converter = CatalystTypeConverters.createToScalaConverter(schema) rows.map(converter(_).asInstanceOf[Row]) } } - private[sql] def internalRowRdd = queryExecution.executedPlan.execute() - /** * Returns the content of the [[DataFrame]] as a [[JavaRDD]] of [[Row]]s. * @group rdd diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala index 3ebbf96090a55..4e2e2c210d5a9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala @@ -90,7 +90,7 @@ private[sql] object FrequentItems extends Logging { (name, originalSchema.fields(index).dataType) } - val freqItems = df.select(cols.map(Column(_)) : _*).internalRowRdd.aggregate(countMaps)( + val freqItems = df.select(cols.map(Column(_)) : _*).queryExecution.toRdd.aggregate(countMaps)( seqOp = (counts, row) => { var i = 0 while (i < numCols) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala index b624ef7e8fa1a..23ddfa9839e5e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala @@ -82,7 +82,7 @@ private[sql] object StatFunctions extends Logging { s"with dataType ${data.get.dataType} not supported.") } val columns = cols.map(n => Column(Cast(Column(n).expr, DoubleType))) - df.select(columns: _*).internalRowRdd.aggregate(new CovarianceCounter)( + df.select(columns: _*).queryExecution.toRdd.aggregate(new CovarianceCounter)( seqOp = (counter, row) => { counter.add(row.getDouble(0), row.getDouble(1)) }, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala index 42b51caab5ce9..7214eb0b4169a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/commands.scala @@ -154,7 +154,7 @@ private[sql] case class InsertIntoHadoopFsRelation( writerContainer.driverSideSetup() try { - df.sqlContext.sparkContext.runJob(df.internalRowRdd, writeRows _) + df.sqlContext.sparkContext.runJob(df.queryExecution.toRdd, writeRows _) writerContainer.commitJob() 
relation.refresh() } catch { case cause: Throwable => @@ -220,7 +220,7 @@ private[sql] case class InsertIntoHadoopFsRelation( writerContainer.driverSideSetup() try { - df.sqlContext.sparkContext.runJob(df.internalRowRdd, writeRows _) + df.sqlContext.sparkContext.runJob(df.queryExecution.toRdd, writeRows _) writerContainer.commitJob() relation.refresh() } catch { case cause: Throwable => From 97652416e22ae7d4c471178377a7dda61afb1f7a Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 1 Jul 2015 01:08:20 -0700 Subject: [PATCH 117/122] [SPARK-8750][SQL] Remove the closure in functions.callUdf. Author: Reynold Xin Closes #7148 from rxin/calludf-closure and squashes the following commits: 00df372 [Reynold Xin] Fixed index out of bound exception. 4beba76 [Reynold Xin] [SPARK-8750][SQL] Remove the closure in functions.callUdf. --- .../main/scala/org/apache/spark/sql/functions.scala | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 5767668dd339b..4e8f3f96bf4db 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1829,7 +1829,15 @@ object functions { */ @deprecated("Use callUDF", "1.5.0") def callUdf(udfName: String, cols: Column*): Column = { - UnresolvedFunction(udfName, cols.map(_.expr)) + // Note: we avoid using closures here because on file systems that are case-insensitive, the + // compiled class file for the closure here will conflict with the one in callUDF (upper case). + val exprs = new Array[Expression](cols.size) + var i = 0 + while (i < cols.size) { + exprs(i) = cols(i).expr + i += 1 + } + UnresolvedFunction(udfName, exprs) } } From fdcad6ef48a9e790776c316124bd6478ab6bd5c8 Mon Sep 17 00:00:00 2001 From: cocoatomo Date: Wed, 1 Jul 2015 09:37:09 -0700 Subject: [PATCH 118/122] [SPARK-8763] [PYSPARK] executing run-tests.py with Python 2.6 fails with absence of subprocess.check_output function Running run-tests.py with Python 2.6 cause following error: ``` Running PySpark tests. Output is in python//Users/tomohiko/.jenkins/jobs/pyspark_test/workspace/python/unit-tests.log Will test against the following Python executables: ['python2.6', 'python3.4', 'pypy'] Will test the following Python modules: ['pyspark-core', 'pyspark-ml', 'pyspark-mllib', 'pyspark-sql', 'pyspark-streaming'] Traceback (most recent call last): File "./python/run-tests.py", line 196, in main() File "./python/run-tests.py", line 159, in main python_implementation = subprocess.check_output( AttributeError: 'module' object has no attribute 'check_output' ... ``` The cause of this error is using subprocess.check_output function, which exists since Python 2.7. (ref. 
https://docs.python.org/2.7/library/subprocess.html#subprocess.check_output) Author: cocoatomo Closes #7161 from cocoatomo/issues/8763-test-fails-py26 and squashes the following commits: cf4f901 [cocoatomo] [SPARK-8763] backport process.check_output function from Python 2.7 --- python/run-tests.py | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/python/run-tests.py b/python/run-tests.py index b7737650daa54..7638854def2e8 100755 --- a/python/run-tests.py +++ b/python/run-tests.py @@ -31,6 +31,23 @@ import Queue else: import queue as Queue +if sys.version_info >= (2, 7): + subprocess_check_output = subprocess.check_output +else: + # SPARK-8763 + # backported from subprocess module in Python 2.7 + def subprocess_check_output(*popenargs, **kwargs): + if 'stdout' in kwargs: + raise ValueError('stdout argument not allowed, it will be overridden.') + process = subprocess.Popen(stdout=subprocess.PIPE, *popenargs, **kwargs) + output, unused_err = process.communicate() + retcode = process.poll() + if retcode: + cmd = kwargs.get("args") + if cmd is None: + cmd = popenargs[0] + raise subprocess.CalledProcessError(retcode, cmd, output=output) + return output # Append `SPARK_HOME/dev` to the Python path so that we can import the sparktestsupport module @@ -156,11 +173,11 @@ def main(): task_queue = Queue.Queue() for python_exec in python_execs: - python_implementation = subprocess.check_output( + python_implementation = subprocess_check_output( [python_exec, "-c", "import platform; print(platform.python_implementation())"], universal_newlines=True).strip() LOGGER.debug("%s python_implementation is %s", python_exec, python_implementation) - LOGGER.debug("%s version is: %s", python_exec, subprocess.check_output( + LOGGER.debug("%s version is: %s", python_exec, subprocess_check_output( [python_exec, "--version"], stderr=subprocess.STDOUT, universal_newlines=True).strip()) for module in modules_to_test: if python_implementation not in module.blacklisted_python_implementations: From 69c5dee2f01b1ae35bd813d31d46429a32cb475d Mon Sep 17 00:00:00 2001 From: Sun Rui Date: Wed, 1 Jul 2015 09:50:12 -0700 Subject: [PATCH 119/122] [SPARK-7714] [SPARKR] SparkR tests should use more specific expectations than expect_true 1. Update the pattern 'expect_true(a == b)' to 'expect_equal(a, b)'. 2. Update the pattern 'expect_true(inherits(a, b))' to 'expect_is(a, b)'. 3. Update the pattern 'expect_true(identical(a, b))' to 'expect_identical(a, b)'. Author: Sun Rui Closes #7152 from sun-rui/SPARK-7714 and squashes the following commits: 8ad2440 [Sun Rui] Fix test case errors. 8fe9f0c [Sun Rui] Update the pattern 'expect_true(identical(a, b))' to 'expect_identical(a, b)'. f1b8005 [Sun Rui] Update the pattern 'expect_true(inherits(a, b))' to 'expect_is(a, b)'. f631e94 [Sun Rui] Update the pattern 'expect_true(a == b)' to 'expect_equal(a, b)'. 
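For readers more at home in the Scala suites of this tree, the change is in the same spirit as preferring ScalaTest's `===` and matcher syntax (as MasterSuite already does) over asserting a raw boolean: the specific form carries the compared values into the failure message. A purely illustrative sketch, not part of this patch:

```
import org.scalatest.{FunSuite, Matchers}

// Illustrative only: the === and matcher forms report the mismatching values on failure.
class ExpectationStyleSuite extends FunSuite with Matchers {
  test("prefer specific expectations over raw booleans") {
    val workers = Seq("worker-1", "worker-2")
    assert(workers.length === 2)        // analogous to expect_equal(length(workers), 2)
    workers should contain ("worker-1") // a targeted matcher rather than expect_true(...)
  }
}
```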
--- R/pkg/inst/tests/test_binaryFile.R | 2 +- R/pkg/inst/tests/test_binary_function.R | 4 +- R/pkg/inst/tests/test_includeJAR.R | 4 +- R/pkg/inst/tests/test_parallelize_collect.R | 2 +- R/pkg/inst/tests/test_rdd.R | 4 +- R/pkg/inst/tests/test_sparkSQL.R | 354 ++++++++++---------- R/pkg/inst/tests/test_take.R | 8 +- R/pkg/inst/tests/test_textFile.R | 6 +- R/pkg/inst/tests/test_utils.R | 4 +- 9 files changed, 194 insertions(+), 194 deletions(-) diff --git a/R/pkg/inst/tests/test_binaryFile.R b/R/pkg/inst/tests/test_binaryFile.R index 4db7266abc8e2..ccaea18ecab2a 100644 --- a/R/pkg/inst/tests/test_binaryFile.R +++ b/R/pkg/inst/tests/test_binaryFile.R @@ -82,7 +82,7 @@ test_that("saveAsObjectFile()/objectFile() works with multiple paths", { saveAsObjectFile(rdd2, fileName2) rdd <- objectFile(sc, c(fileName1, fileName2)) - expect_true(count(rdd) == 2) + expect_equal(count(rdd), 2) unlink(fileName1, recursive = TRUE) unlink(fileName2, recursive = TRUE) diff --git a/R/pkg/inst/tests/test_binary_function.R b/R/pkg/inst/tests/test_binary_function.R index a1e354e567be5..3be8c65a6c1a0 100644 --- a/R/pkg/inst/tests/test_binary_function.R +++ b/R/pkg/inst/tests/test_binary_function.R @@ -38,13 +38,13 @@ test_that("union on two RDDs", { union.rdd <- unionRDD(rdd, text.rdd) actual <- collect(union.rdd) expect_equal(actual, c(as.list(nums), mockFile)) - expect_true(getSerializedMode(union.rdd) == "byte") + expect_equal(getSerializedMode(union.rdd), "byte") rdd<- map(text.rdd, function(x) {x}) union.rdd <- unionRDD(rdd, text.rdd) actual <- collect(union.rdd) expect_equal(actual, as.list(c(mockFile, mockFile))) - expect_true(getSerializedMode(union.rdd) == "byte") + expect_equal(getSerializedMode(union.rdd), "byte") unlink(fileName) }) diff --git a/R/pkg/inst/tests/test_includeJAR.R b/R/pkg/inst/tests/test_includeJAR.R index 8bc693be20c3c..844d86f3cc97f 100644 --- a/R/pkg/inst/tests/test_includeJAR.R +++ b/R/pkg/inst/tests/test_includeJAR.R @@ -31,7 +31,7 @@ runScript <- function() { test_that("sparkJars tag in SparkContext", { testOutput <- runScript() helloTest <- testOutput[1] - expect_true(helloTest == "Hello, Dave") + expect_equal(helloTest, "Hello, Dave") basicFunction <- testOutput[2] - expect_true(basicFunction == 4L) + expect_equal(basicFunction, "4") }) diff --git a/R/pkg/inst/tests/test_parallelize_collect.R b/R/pkg/inst/tests/test_parallelize_collect.R index fff028657db37..2552127cc547f 100644 --- a/R/pkg/inst/tests/test_parallelize_collect.R +++ b/R/pkg/inst/tests/test_parallelize_collect.R @@ -57,7 +57,7 @@ test_that("parallelize() on simple vectors and lists returns an RDD", { strListRDD2) for (rdd in rdds) { - expect_true(inherits(rdd, "RDD")) + expect_is(rdd, "RDD") expect_true(.hasSlot(rdd, "jrdd") && inherits(rdd@jrdd, "jobj") && isInstanceOf(rdd@jrdd, "org.apache.spark.api.java.JavaRDD")) diff --git a/R/pkg/inst/tests/test_rdd.R b/R/pkg/inst/tests/test_rdd.R index 4fe653856756e..fc3c01d837de4 100644 --- a/R/pkg/inst/tests/test_rdd.R +++ b/R/pkg/inst/tests/test_rdd.R @@ -33,9 +33,9 @@ test_that("get number of partitions in RDD", { }) test_that("first on RDD", { - expect_true(first(rdd) == 1) + expect_equal(first(rdd), 1) newrdd <- lapply(rdd, function(x) x + 1) - expect_true(first(newrdd) == 2) + expect_equal(first(newrdd), 2) }) test_that("count and length on RDD", { diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R index 6a08f894313c4..0e4235ea8b4b3 100644 --- a/R/pkg/inst/tests/test_sparkSQL.R +++ b/R/pkg/inst/tests/test_sparkSQL.R @@ -61,7 +61,7 @@ 
test_that("infer types", { expect_equal(infer_type(list(1L, 2L)), list(type = 'array', elementType = "integer", containsNull = TRUE)) testStruct <- infer_type(list(a = 1L, b = "2")) - expect_true(class(testStruct) == "structType") + expect_equal(class(testStruct), "structType") checkStructField(testStruct$fields()[[1]], "a", "IntegerType", TRUE) checkStructField(testStruct$fields()[[2]], "b", "StringType", TRUE) e <- new.env() @@ -73,39 +73,39 @@ test_that("infer types", { test_that("structType and structField", { testField <- structField("a", "string") - expect_true(inherits(testField, "structField")) - expect_true(testField$name() == "a") + expect_is(testField, "structField") + expect_equal(testField$name(), "a") expect_true(testField$nullable()) testSchema <- structType(testField, structField("b", "integer")) - expect_true(inherits(testSchema, "structType")) - expect_true(inherits(testSchema$fields()[[2]], "structField")) - expect_true(testSchema$fields()[[1]]$dataType.toString() == "StringType") + expect_is(testSchema, "structType") + expect_is(testSchema$fields()[[2]], "structField") + expect_equal(testSchema$fields()[[1]]$dataType.toString(), "StringType") }) test_that("create DataFrame from RDD", { rdd <- lapply(parallelize(sc, 1:10), function(x) { list(x, as.character(x)) }) df <- createDataFrame(sqlContext, rdd, list("a", "b")) - expect_true(inherits(df, "DataFrame")) - expect_true(count(df) == 10) + expect_is(df, "DataFrame") + expect_equal(count(df), 10) expect_equal(columns(df), c("a", "b")) expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) df <- createDataFrame(sqlContext, rdd) - expect_true(inherits(df, "DataFrame")) + expect_is(df, "DataFrame") expect_equal(columns(df), c("_1", "_2")) schema <- structType(structField(x = "a", type = "integer", nullable = TRUE), structField(x = "b", type = "string", nullable = TRUE)) df <- createDataFrame(sqlContext, rdd, schema) - expect_true(inherits(df, "DataFrame")) + expect_is(df, "DataFrame") expect_equal(columns(df), c("a", "b")) expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) rdd <- lapply(parallelize(sc, 1:10), function(x) { list(a = x, b = as.character(x)) }) df <- createDataFrame(sqlContext, rdd) - expect_true(inherits(df, "DataFrame")) - expect_true(count(df) == 10) + expect_is(df, "DataFrame") + expect_equal(count(df), 10) expect_equal(columns(df), c("a", "b")) expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) }) @@ -150,26 +150,26 @@ test_that("convert NAs to null type in DataFrames", { test_that("toDF", { rdd <- lapply(parallelize(sc, 1:10), function(x) { list(x, as.character(x)) }) df <- toDF(rdd, list("a", "b")) - expect_true(inherits(df, "DataFrame")) - expect_true(count(df) == 10) + expect_is(df, "DataFrame") + expect_equal(count(df), 10) expect_equal(columns(df), c("a", "b")) expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) df <- toDF(rdd) - expect_true(inherits(df, "DataFrame")) + expect_is(df, "DataFrame") expect_equal(columns(df), c("_1", "_2")) schema <- structType(structField(x = "a", type = "integer", nullable = TRUE), structField(x = "b", type = "string", nullable = TRUE)) df <- toDF(rdd, schema) - expect_true(inherits(df, "DataFrame")) + expect_is(df, "DataFrame") expect_equal(columns(df), c("a", "b")) expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) rdd <- lapply(parallelize(sc, 1:10), function(x) { list(a = x, b = as.character(x)) }) df <- toDF(rdd) - expect_true(inherits(df, "DataFrame")) - expect_true(count(df) == 10) + expect_is(df, 
"DataFrame") + expect_equal(count(df), 10) expect_equal(columns(df), c("a", "b")) expect_equal(dtypes(df), list(c("a", "int"), c("b", "string"))) }) @@ -219,21 +219,21 @@ test_that("create DataFrame with different data types", { test_that("jsonFile() on a local file returns a DataFrame", { df <- jsonFile(sqlContext, jsonPath) - expect_true(inherits(df, "DataFrame")) - expect_true(count(df) == 3) + expect_is(df, "DataFrame") + expect_equal(count(df), 3) }) test_that("jsonRDD() on a RDD with json string", { rdd <- parallelize(sc, mockLines) - expect_true(count(rdd) == 3) + expect_equal(count(rdd), 3) df <- jsonRDD(sqlContext, rdd) - expect_true(inherits(df, "DataFrame")) - expect_true(count(df) == 3) + expect_is(df, "DataFrame") + expect_equal(count(df), 3) rdd2 <- flatMap(rdd, function(x) c(x, x)) df <- jsonRDD(sqlContext, rdd2) - expect_true(inherits(df, "DataFrame")) - expect_true(count(df) == 6) + expect_is(df, "DataFrame") + expect_equal(count(df), 6) }) test_that("test cache, uncache and clearCache", { @@ -248,9 +248,9 @@ test_that("test cache, uncache and clearCache", { test_that("test tableNames and tables", { df <- jsonFile(sqlContext, jsonPath) registerTempTable(df, "table1") - expect_true(length(tableNames(sqlContext)) == 1) + expect_equal(length(tableNames(sqlContext)), 1) df <- tables(sqlContext) - expect_true(count(df) == 1) + expect_equal(count(df), 1) dropTempTable(sqlContext, "table1") }) @@ -258,8 +258,8 @@ test_that("registerTempTable() results in a queryable table and sql() results in df <- jsonFile(sqlContext, jsonPath) registerTempTable(df, "table1") newdf <- sql(sqlContext, "SELECT * FROM table1 where name = 'Michael'") - expect_true(inherits(newdf, "DataFrame")) - expect_true(count(newdf) == 1) + expect_is(newdf, "DataFrame") + expect_equal(count(newdf), 1) dropTempTable(sqlContext, "table1") }) @@ -279,14 +279,14 @@ test_that("insertInto() on a registered table", { registerTempTable(dfParquet, "table1") insertInto(dfParquet2, "table1") - expect_true(count(sql(sqlContext, "select * from table1")) == 5) - expect_true(first(sql(sqlContext, "select * from table1 order by age"))$name == "Michael") + expect_equal(count(sql(sqlContext, "select * from table1")), 5) + expect_equal(first(sql(sqlContext, "select * from table1 order by age"))$name, "Michael") dropTempTable(sqlContext, "table1") registerTempTable(dfParquet, "table1") insertInto(dfParquet2, "table1", overwrite = TRUE) - expect_true(count(sql(sqlContext, "select * from table1")) == 2) - expect_true(first(sql(sqlContext, "select * from table1 order by age"))$name == "Bob") + expect_equal(count(sql(sqlContext, "select * from table1")), 2) + expect_equal(first(sql(sqlContext, "select * from table1 order by age"))$name, "Bob") dropTempTable(sqlContext, "table1") }) @@ -294,16 +294,16 @@ test_that("table() returns a new DataFrame", { df <- jsonFile(sqlContext, jsonPath) registerTempTable(df, "table1") tabledf <- table(sqlContext, "table1") - expect_true(inherits(tabledf, "DataFrame")) - expect_true(count(tabledf) == 3) + expect_is(tabledf, "DataFrame") + expect_equal(count(tabledf), 3) dropTempTable(sqlContext, "table1") }) test_that("toRDD() returns an RRDD", { df <- jsonFile(sqlContext, jsonPath) testRDD <- toRDD(df) - expect_true(inherits(testRDD, "RDD")) - expect_true(count(testRDD) == 3) + expect_is(testRDD, "RDD") + expect_equal(count(testRDD), 3) }) test_that("union on two RDDs created from DataFrames returns an RRDD", { @@ -311,9 +311,9 @@ test_that("union on two RDDs created from DataFrames returns an RRDD", { 
RDD1 <- toRDD(df) RDD2 <- toRDD(df) unioned <- unionRDD(RDD1, RDD2) - expect_true(inherits(unioned, "RDD")) - expect_true(SparkR:::getSerializedMode(unioned) == "byte") - expect_true(collect(unioned)[[2]]$name == "Andy") + expect_is(unioned, "RDD") + expect_equal(SparkR:::getSerializedMode(unioned), "byte") + expect_equal(collect(unioned)[[2]]$name, "Andy") }) test_that("union on mixed serialization types correctly returns a byte RRDD", { @@ -333,16 +333,16 @@ test_that("union on mixed serialization types correctly returns a byte RRDD", { dfRDD <- toRDD(df) unionByte <- unionRDD(rdd, dfRDD) - expect_true(inherits(unionByte, "RDD")) - expect_true(SparkR:::getSerializedMode(unionByte) == "byte") - expect_true(collect(unionByte)[[1]] == 1) - expect_true(collect(unionByte)[[12]]$name == "Andy") + expect_is(unionByte, "RDD") + expect_equal(SparkR:::getSerializedMode(unionByte), "byte") + expect_equal(collect(unionByte)[[1]], 1) + expect_equal(collect(unionByte)[[12]]$name, "Andy") unionString <- unionRDD(textRDD, dfRDD) - expect_true(inherits(unionString, "RDD")) - expect_true(SparkR:::getSerializedMode(unionString) == "byte") - expect_true(collect(unionString)[[1]] == "Michael") - expect_true(collect(unionString)[[5]]$name == "Andy") + expect_is(unionString, "RDD") + expect_equal(SparkR:::getSerializedMode(unionString), "byte") + expect_equal(collect(unionString)[[1]], "Michael") + expect_equal(collect(unionString)[[5]]$name, "Andy") }) test_that("objectFile() works with row serialization", { @@ -352,7 +352,7 @@ test_that("objectFile() works with row serialization", { saveAsObjectFile(coalesce(dfRDD, 1L), objectPath) objectIn <- objectFile(sc, objectPath) - expect_true(inherits(objectIn, "RDD")) + expect_is(objectIn, "RDD") expect_equal(SparkR:::getSerializedMode(objectIn), "byte") expect_equal(collect(objectIn)[[2]]$age, 30) }) @@ -363,32 +363,32 @@ test_that("lapply() on a DataFrame returns an RDD with the correct columns", { row$newCol <- row$age + 5 row }) - expect_true(inherits(testRDD, "RDD")) + expect_is(testRDD, "RDD") collected <- collect(testRDD) - expect_true(collected[[1]]$name == "Michael") - expect_true(collected[[2]]$newCol == "35") + expect_equal(collected[[1]]$name, "Michael") + expect_equal(collected[[2]]$newCol, 35) }) test_that("collect() returns a data.frame", { df <- jsonFile(sqlContext, jsonPath) rdf <- collect(df) expect_true(is.data.frame(rdf)) - expect_true(names(rdf)[1] == "age") - expect_true(nrow(rdf) == 3) - expect_true(ncol(rdf) == 2) + expect_equal(names(rdf)[1], "age") + expect_equal(nrow(rdf), 3) + expect_equal(ncol(rdf), 2) }) test_that("limit() returns DataFrame with the correct number of rows", { df <- jsonFile(sqlContext, jsonPath) dfLimited <- limit(df, 2) - expect_true(inherits(dfLimited, "DataFrame")) - expect_true(count(dfLimited) == 2) + expect_is(dfLimited, "DataFrame") + expect_equal(count(dfLimited), 2) }) test_that("collect() and take() on a DataFrame return the same number of rows and columns", { df <- jsonFile(sqlContext, jsonPath) - expect_true(nrow(collect(df)) == nrow(take(df, 10))) - expect_true(ncol(collect(df)) == ncol(take(df, 10))) + expect_equal(nrow(collect(df)), nrow(take(df, 10))) + expect_equal(ncol(collect(df)), ncol(take(df, 10))) }) test_that("multiple pipeline transformations starting with a DataFrame result in an RDD with the correct values", { @@ -401,9 +401,9 @@ test_that("multiple pipeline transformations starting with a DataFrame result in row$testCol <- if (row$age == 35 && !is.na(row$age)) TRUE else FALSE row }) - 
expect_true(inherits(second, "RDD")) - expect_true(count(second) == 3) - expect_true(collect(second)[[2]]$age == 35) + expect_is(second, "RDD") + expect_equal(count(second), 3) + expect_equal(collect(second)[[2]]$age, 35) expect_true(collect(second)[[2]]$testCol) expect_false(collect(second)[[3]]$testCol) }) @@ -430,36 +430,36 @@ test_that("cache(), persist(), and unpersist() on a DataFrame", { test_that("schema(), dtypes(), columns(), names() return the correct values/format", { df <- jsonFile(sqlContext, jsonPath) testSchema <- schema(df) - expect_true(length(testSchema$fields()) == 2) - expect_true(testSchema$fields()[[1]]$dataType.toString() == "LongType") - expect_true(testSchema$fields()[[2]]$dataType.simpleString() == "string") - expect_true(testSchema$fields()[[1]]$name() == "age") + expect_equal(length(testSchema$fields()), 2) + expect_equal(testSchema$fields()[[1]]$dataType.toString(), "LongType") + expect_equal(testSchema$fields()[[2]]$dataType.simpleString(), "string") + expect_equal(testSchema$fields()[[1]]$name(), "age") testTypes <- dtypes(df) - expect_true(length(testTypes[[1]]) == 2) - expect_true(testTypes[[1]][1] == "age") + expect_equal(length(testTypes[[1]]), 2) + expect_equal(testTypes[[1]][1], "age") testCols <- columns(df) - expect_true(length(testCols) == 2) - expect_true(testCols[2] == "name") + expect_equal(length(testCols), 2) + expect_equal(testCols[2], "name") testNames <- names(df) - expect_true(length(testNames) == 2) - expect_true(testNames[2] == "name") + expect_equal(length(testNames), 2) + expect_equal(testNames[2], "name") }) test_that("head() and first() return the correct data", { df <- jsonFile(sqlContext, jsonPath) testHead <- head(df) - expect_true(nrow(testHead) == 3) - expect_true(ncol(testHead) == 2) + expect_equal(nrow(testHead), 3) + expect_equal(ncol(testHead), 2) testHead2 <- head(df, 2) - expect_true(nrow(testHead2) == 2) - expect_true(ncol(testHead2) == 2) + expect_equal(nrow(testHead2), 2) + expect_equal(ncol(testHead2), 2) testFirst <- first(df) - expect_true(nrow(testFirst) == 1) + expect_equal(nrow(testFirst), 1) }) test_that("distinct() on DataFrames", { @@ -472,15 +472,15 @@ test_that("distinct() on DataFrames", { df <- jsonFile(sqlContext, jsonPathWithDup) uniques <- distinct(df) - expect_true(inherits(uniques, "DataFrame")) - expect_true(count(uniques) == 3) + expect_is(uniques, "DataFrame") + expect_equal(count(uniques), 3) }) test_that("sample on a DataFrame", { df <- jsonFile(sqlContext, jsonPath) sampled <- sample(df, FALSE, 1.0) expect_equal(nrow(collect(sampled)), count(df)) - expect_true(inherits(sampled, "DataFrame")) + expect_is(sampled, "DataFrame") sampled2 <- sample(df, FALSE, 0.1) expect_true(count(sampled2) < 3) @@ -491,15 +491,15 @@ test_that("sample on a DataFrame", { test_that("select operators", { df <- select(jsonFile(sqlContext, jsonPath), "name", "age") - expect_true(inherits(df$name, "Column")) - expect_true(inherits(df[[2]], "Column")) - expect_true(inherits(df[["age"]], "Column")) + expect_is(df$name, "Column") + expect_is(df[[2]], "Column") + expect_is(df[["age"]], "Column") - expect_true(inherits(df[,1], "DataFrame")) + expect_is(df[,1], "DataFrame") expect_equal(columns(df[,1]), c("name")) expect_equal(columns(df[,"age"]), c("age")) df2 <- df[,c("age", "name")] - expect_true(inherits(df2, "DataFrame")) + expect_is(df2, "DataFrame") expect_equal(columns(df2), c("age", "name")) df$age2 <- df$age @@ -518,50 +518,50 @@ test_that("select operators", { test_that("select with column", { df <- 
jsonFile(sqlContext, jsonPath) df1 <- select(df, "name") - expect_true(columns(df1) == c("name")) - expect_true(count(df1) == 3) + expect_equal(columns(df1), c("name")) + expect_equal(count(df1), 3) df2 <- select(df, df$age) - expect_true(columns(df2) == c("age")) - expect_true(count(df2) == 3) + expect_equal(columns(df2), c("age")) + expect_equal(count(df2), 3) }) test_that("selectExpr() on a DataFrame", { df <- jsonFile(sqlContext, jsonPath) selected <- selectExpr(df, "age * 2") - expect_true(names(selected) == "(age * 2)") + expect_equal(names(selected), "(age * 2)") expect_equal(collect(selected), collect(select(df, df$age * 2L))) selected2 <- selectExpr(df, "name as newName", "abs(age) as age") expect_equal(names(selected2), c("newName", "age")) - expect_true(count(selected2) == 3) + expect_equal(count(selected2), 3) }) test_that("column calculation", { df <- jsonFile(sqlContext, jsonPath) d <- collect(select(df, alias(df$age + 1, "age2"))) - expect_true(names(d) == c("age2")) + expect_equal(names(d), c("age2")) df2 <- select(df, lower(df$name), abs(df$age)) - expect_true(inherits(df2, "DataFrame")) - expect_true(count(df2) == 3) + expect_is(df2, "DataFrame") + expect_equal(count(df2), 3) }) test_that("read.df() from json file", { df <- read.df(sqlContext, jsonPath, "json") - expect_true(inherits(df, "DataFrame")) - expect_true(count(df) == 3) + expect_is(df, "DataFrame") + expect_equal(count(df), 3) # Check if we can apply a user defined schema schema <- structType(structField("name", type = "string"), structField("age", type = "double")) df1 <- read.df(sqlContext, jsonPath, "json", schema) - expect_true(inherits(df1, "DataFrame")) + expect_is(df1, "DataFrame") expect_equal(dtypes(df1), list(c("name", "string"), c("age", "double"))) # Run the same with loadDF df2 <- loadDF(sqlContext, jsonPath, "json", schema) - expect_true(inherits(df2, "DataFrame")) + expect_is(df2, "DataFrame") expect_equal(dtypes(df2), list(c("name", "string"), c("age", "double"))) }) @@ -569,8 +569,8 @@ test_that("write.df() as parquet file", { df <- read.df(sqlContext, jsonPath, "json") write.df(df, parquetPath, "parquet", mode="overwrite") df2 <- read.df(sqlContext, parquetPath, "parquet") - expect_true(inherits(df2, "DataFrame")) - expect_true(count(df2) == 3) + expect_is(df2, "DataFrame") + expect_equal(count(df2), 3) }) test_that("test HiveContext", { @@ -580,17 +580,17 @@ test_that("test HiveContext", { skip("Hive is not build with SparkSQL, skipped") }) df <- createExternalTable(hiveCtx, "json", jsonPath, "json") - expect_true(inherits(df, "DataFrame")) - expect_true(count(df) == 3) + expect_is(df, "DataFrame") + expect_equal(count(df), 3) df2 <- sql(hiveCtx, "select * from json") - expect_true(inherits(df2, "DataFrame")) - expect_true(count(df2) == 3) + expect_is(df2, "DataFrame") + expect_equal(count(df2), 3) jsonPath2 <- tempfile(pattern="sparkr-test", fileext=".tmp") saveAsTable(df, "json", "json", "append", path = jsonPath2) df3 <- sql(hiveCtx, "select * from json") - expect_true(inherits(df3, "DataFrame")) - expect_true(count(df3) == 6) + expect_is(df3, "DataFrame") + expect_equal(count(df3), 6) }) test_that("column operators", { @@ -643,65 +643,65 @@ test_that("string operators", { test_that("group by", { df <- jsonFile(sqlContext, jsonPath) df1 <- agg(df, name = "max", age = "sum") - expect_true(1 == count(df1)) + expect_equal(1, count(df1)) df1 <- agg(df, age2 = max(df$age)) - expect_true(1 == count(df1)) + expect_equal(1, count(df1)) expect_equal(columns(df1), c("age2")) gd <- groupBy(df, 
"name") - expect_true(inherits(gd, "GroupedData")) + expect_is(gd, "GroupedData") df2 <- count(gd) - expect_true(inherits(df2, "DataFrame")) - expect_true(3 == count(df2)) + expect_is(df2, "DataFrame") + expect_equal(3, count(df2)) # Also test group_by, summarize, mean gd1 <- group_by(df, "name") - expect_true(inherits(gd1, "GroupedData")) + expect_is(gd1, "GroupedData") df_summarized <- summarize(gd, mean_age = mean(df$age)) - expect_true(inherits(df_summarized, "DataFrame")) - expect_true(3 == count(df_summarized)) + expect_is(df_summarized, "DataFrame") + expect_equal(3, count(df_summarized)) df3 <- agg(gd, age = "sum") - expect_true(inherits(df3, "DataFrame")) - expect_true(3 == count(df3)) + expect_is(df3, "DataFrame") + expect_equal(3, count(df3)) df3 <- agg(gd, age = sum(df$age)) - expect_true(inherits(df3, "DataFrame")) - expect_true(3 == count(df3)) + expect_is(df3, "DataFrame") + expect_equal(3, count(df3)) expect_equal(columns(df3), c("name", "age")) df4 <- sum(gd, "age") - expect_true(inherits(df4, "DataFrame")) - expect_true(3 == count(df4)) - expect_true(3 == count(mean(gd, "age"))) - expect_true(3 == count(max(gd, "age"))) + expect_is(df4, "DataFrame") + expect_equal(3, count(df4)) + expect_equal(3, count(mean(gd, "age"))) + expect_equal(3, count(max(gd, "age"))) }) test_that("arrange() and orderBy() on a DataFrame", { df <- jsonFile(sqlContext, jsonPath) sorted <- arrange(df, df$age) - expect_true(collect(sorted)[1,2] == "Michael") + expect_equal(collect(sorted)[1,2], "Michael") sorted2 <- arrange(df, "name") - expect_true(collect(sorted2)[2,"age"] == 19) + expect_equal(collect(sorted2)[2,"age"], 19) sorted3 <- orderBy(df, asc(df$age)) expect_true(is.na(first(sorted3)$age)) - expect_true(collect(sorted3)[2, "age"] == 19) + expect_equal(collect(sorted3)[2, "age"], 19) sorted4 <- orderBy(df, desc(df$name)) - expect_true(first(sorted4)$name == "Michael") - expect_true(collect(sorted4)[3,"name"] == "Andy") + expect_equal(first(sorted4)$name, "Michael") + expect_equal(collect(sorted4)[3,"name"], "Andy") }) test_that("filter() on a DataFrame", { df <- jsonFile(sqlContext, jsonPath) filtered <- filter(df, "age > 20") - expect_true(count(filtered) == 1) - expect_true(collect(filtered)$name == "Andy") + expect_equal(count(filtered), 1) + expect_equal(collect(filtered)$name, "Andy") filtered2 <- where(df, df$name != "Michael") - expect_true(count(filtered2) == 2) - expect_true(collect(filtered2)$age[2] == 19) + expect_equal(count(filtered2), 2) + expect_equal(collect(filtered2)$age[2], 19) # test suites for %in% filtered3 <- filter(df, "age in (19)") @@ -727,29 +727,29 @@ test_that("join() on a DataFrame", { joined <- join(df, df2) expect_equal(names(joined), c("age", "name", "name", "test")) - expect_true(count(joined) == 12) + expect_equal(count(joined), 12) joined2 <- join(df, df2, df$name == df2$name) expect_equal(names(joined2), c("age", "name", "name", "test")) - expect_true(count(joined2) == 3) + expect_equal(count(joined2), 3) joined3 <- join(df, df2, df$name == df2$name, "right_outer") expect_equal(names(joined3), c("age", "name", "name", "test")) - expect_true(count(joined3) == 4) + expect_equal(count(joined3), 4) expect_true(is.na(collect(orderBy(joined3, joined3$age))$age[2])) joined4 <- select(join(df, df2, df$name == df2$name, "outer"), alias(df$age + 5, "newAge"), df$name, df2$test) expect_equal(names(joined4), c("newAge", "name", "test")) - expect_true(count(joined4) == 4) + expect_equal(count(joined4), 4) expect_equal(collect(orderBy(joined4, 
joined4$name))$newAge[3], 24) }) test_that("toJSON() returns an RDD of the correct values", { df <- jsonFile(sqlContext, jsonPath) testRDD <- toJSON(df) - expect_true(inherits(testRDD, "RDD")) - expect_true(SparkR:::getSerializedMode(testRDD) == "string") + expect_is(testRDD, "RDD") + expect_equal(SparkR:::getSerializedMode(testRDD), "string") expect_equal(collect(testRDD)[[1]], mockLines[1]) }) @@ -775,50 +775,50 @@ test_that("unionAll(), except(), and intersect() on a DataFrame", { df2 <- read.df(sqlContext, jsonPath2, "json") unioned <- arrange(unionAll(df, df2), df$age) - expect_true(inherits(unioned, "DataFrame")) - expect_true(count(unioned) == 6) - expect_true(first(unioned)$name == "Michael") + expect_is(unioned, "DataFrame") + expect_equal(count(unioned), 6) + expect_equal(first(unioned)$name, "Michael") excepted <- arrange(except(df, df2), desc(df$age)) - expect_true(inherits(unioned, "DataFrame")) - expect_true(count(excepted) == 2) - expect_true(first(excepted)$name == "Justin") + expect_is(unioned, "DataFrame") + expect_equal(count(excepted), 2) + expect_equal(first(excepted)$name, "Justin") intersected <- arrange(intersect(df, df2), df$age) - expect_true(inherits(unioned, "DataFrame")) - expect_true(count(intersected) == 1) - expect_true(first(intersected)$name == "Andy") + expect_is(unioned, "DataFrame") + expect_equal(count(intersected), 1) + expect_equal(first(intersected)$name, "Andy") }) test_that("withColumn() and withColumnRenamed()", { df <- jsonFile(sqlContext, jsonPath) newDF <- withColumn(df, "newAge", df$age + 2) - expect_true(length(columns(newDF)) == 3) - expect_true(columns(newDF)[3] == "newAge") - expect_true(first(filter(newDF, df$name != "Michael"))$newAge == 32) + expect_equal(length(columns(newDF)), 3) + expect_equal(columns(newDF)[3], "newAge") + expect_equal(first(filter(newDF, df$name != "Michael"))$newAge, 32) newDF2 <- withColumnRenamed(df, "age", "newerAge") - expect_true(length(columns(newDF2)) == 2) - expect_true(columns(newDF2)[1] == "newerAge") + expect_equal(length(columns(newDF2)), 2) + expect_equal(columns(newDF2)[1], "newerAge") }) test_that("mutate() and rename()", { df <- jsonFile(sqlContext, jsonPath) newDF <- mutate(df, newAge = df$age + 2) - expect_true(length(columns(newDF)) == 3) - expect_true(columns(newDF)[3] == "newAge") - expect_true(first(filter(newDF, df$name != "Michael"))$newAge == 32) + expect_equal(length(columns(newDF)), 3) + expect_equal(columns(newDF)[3], "newAge") + expect_equal(first(filter(newDF, df$name != "Michael"))$newAge, 32) newDF2 <- rename(df, newerAge = df$age) - expect_true(length(columns(newDF2)) == 2) - expect_true(columns(newDF2)[1] == "newerAge") + expect_equal(length(columns(newDF2)), 2) + expect_equal(columns(newDF2)[1], "newerAge") }) test_that("write.df() on DataFrame and works with parquetFile", { df <- jsonFile(sqlContext, jsonPath) write.df(df, parquetPath, "parquet", mode="overwrite") parquetDF <- parquetFile(sqlContext, parquetPath) - expect_true(inherits(parquetDF, "DataFrame")) + expect_is(parquetDF, "DataFrame") expect_equal(count(df), count(parquetDF)) }) @@ -828,8 +828,8 @@ test_that("parquetFile works with multiple input paths", { parquetPath2 <- tempfile(pattern = "parquetPath2", fileext = ".parquet") write.df(df, parquetPath2, "parquet", mode="overwrite") parquetDF <- parquetFile(sqlContext, parquetPath, parquetPath2) - expect_true(inherits(parquetDF, "DataFrame")) - expect_true(count(parquetDF) == count(df)*2) + expect_is(parquetDF, "DataFrame") + expect_equal(count(parquetDF), 
count(df)*2) }) test_that("describe() on a DataFrame", { @@ -851,58 +851,58 @@ test_that("dropna() on a DataFrame", { expected <- rows[!is.na(rows$name),] actual <- collect(dropna(df, cols = "name")) - expect_true(identical(expected, actual)) + expect_identical(expected, actual) expected <- rows[!is.na(rows$age),] actual <- collect(dropna(df, cols = "age")) row.names(expected) <- row.names(actual) # identical on two dataframes does not work here. Don't know why. # use identical on all columns as a workaround. - expect_true(identical(expected$age, actual$age)) - expect_true(identical(expected$height, actual$height)) - expect_true(identical(expected$name, actual$name)) + expect_identical(expected$age, actual$age) + expect_identical(expected$height, actual$height) + expect_identical(expected$name, actual$name) expected <- rows[!is.na(rows$age) & !is.na(rows$height),] actual <- collect(dropna(df, cols = c("age", "height"))) - expect_true(identical(expected, actual)) + expect_identical(expected, actual) expected <- rows[!is.na(rows$age) & !is.na(rows$height) & !is.na(rows$name),] actual <- collect(dropna(df)) - expect_true(identical(expected, actual)) + expect_identical(expected, actual) # drop with how expected <- rows[!is.na(rows$age) & !is.na(rows$height) & !is.na(rows$name),] actual <- collect(dropna(df)) - expect_true(identical(expected, actual)) + expect_identical(expected, actual) expected <- rows[!is.na(rows$age) | !is.na(rows$height) | !is.na(rows$name),] actual <- collect(dropna(df, "all")) - expect_true(identical(expected, actual)) + expect_identical(expected, actual) expected <- rows[!is.na(rows$age) & !is.na(rows$height) & !is.na(rows$name),] actual <- collect(dropna(df, "any")) - expect_true(identical(expected, actual)) + expect_identical(expected, actual) expected <- rows[!is.na(rows$age) & !is.na(rows$height),] actual <- collect(dropna(df, "any", cols = c("age", "height"))) - expect_true(identical(expected, actual)) + expect_identical(expected, actual) expected <- rows[!is.na(rows$age) | !is.na(rows$height),] actual <- collect(dropna(df, "all", cols = c("age", "height"))) - expect_true(identical(expected, actual)) + expect_identical(expected, actual) # drop with threshold expected <- rows[as.integer(!is.na(rows$age)) + as.integer(!is.na(rows$height)) >= 2,] actual <- collect(dropna(df, minNonNulls = 2, cols = c("age", "height"))) - expect_true(identical(expected, actual)) + expect_identical(expected, actual) expected <- rows[as.integer(!is.na(rows$age)) + as.integer(!is.na(rows$height)) + as.integer(!is.na(rows$name)) >= 3,] actual <- collect(dropna(df, minNonNulls = 3, cols = c("name", "age", "height"))) - expect_true(identical(expected, actual)) + expect_identical(expected, actual) }) test_that("fillna() on a DataFrame", { @@ -915,22 +915,22 @@ test_that("fillna() on a DataFrame", { expected$age[is.na(expected$age)] <- 50 expected$height[is.na(expected$height)] <- 50.6 actual <- collect(fillna(df, 50.6)) - expect_true(identical(expected, actual)) + expect_identical(expected, actual) expected <- rows expected$name[is.na(expected$name)] <- "unknown" actual <- collect(fillna(df, "unknown")) - expect_true(identical(expected, actual)) + expect_identical(expected, actual) expected <- rows expected$age[is.na(expected$age)] <- 50 actual <- collect(fillna(df, 50.6, "age")) - expect_true(identical(expected, actual)) + expect_identical(expected, actual) expected <- rows expected$name[is.na(expected$name)] <- "unknown" actual <- collect(fillna(df, "unknown", c("age", "name"))) - 
expect_true(identical(expected, actual)) + expect_identical(expected, actual) # fill with named list @@ -939,7 +939,7 @@ test_that("fillna() on a DataFrame", { expected$height[is.na(expected$height)] <- 50.6 expected$name[is.na(expected$name)] <- "unknown" actual <- collect(fillna(df, list("age" = 50, "height" = 50.6, "name" = "unknown"))) - expect_true(identical(expected, actual)) + expect_identical(expected, actual) }) unlink(parquetPath) diff --git a/R/pkg/inst/tests/test_take.R b/R/pkg/inst/tests/test_take.R index c5eb417b40159..c2c724cdc762f 100644 --- a/R/pkg/inst/tests/test_take.R +++ b/R/pkg/inst/tests/test_take.R @@ -59,8 +59,8 @@ test_that("take() gives back the original elements in correct count and order", expect_equal(take(strListRDD, 3), as.list(head(strList, n = 3))) expect_equal(take(strListRDD2, 1), as.list(head(strList, n = 1))) - expect_true(length(take(strListRDD, 0)) == 0) - expect_true(length(take(strVectorRDD, 0)) == 0) - expect_true(length(take(numListRDD, 0)) == 0) - expect_true(length(take(numVectorRDD, 0)) == 0) + expect_equal(length(take(strListRDD, 0)), 0) + expect_equal(length(take(strVectorRDD, 0)), 0) + expect_equal(length(take(numListRDD, 0)), 0) + expect_equal(length(take(numVectorRDD, 0)), 0) }) diff --git a/R/pkg/inst/tests/test_textFile.R b/R/pkg/inst/tests/test_textFile.R index 092ad9dc10c2e..58318dfef71ab 100644 --- a/R/pkg/inst/tests/test_textFile.R +++ b/R/pkg/inst/tests/test_textFile.R @@ -27,9 +27,9 @@ test_that("textFile() on a local file returns an RDD", { writeLines(mockFile, fileName) rdd <- textFile(sc, fileName) - expect_true(inherits(rdd, "RDD")) + expect_is(rdd, "RDD") expect_true(count(rdd) > 0) - expect_true(count(rdd) == 2) + expect_equal(count(rdd), 2) unlink(fileName) }) @@ -133,7 +133,7 @@ test_that("textFile() on multiple paths", { writeLines("Spark is awesome.", fileName2) rdd <- textFile(sc, c(fileName1, fileName2)) - expect_true(count(rdd) == 2) + expect_equal(count(rdd), 2) unlink(fileName1) unlink(fileName2) diff --git a/R/pkg/inst/tests/test_utils.R b/R/pkg/inst/tests/test_utils.R index 15030e6f1d77e..aa0d2a66b9082 100644 --- a/R/pkg/inst/tests/test_utils.R +++ b/R/pkg/inst/tests/test_utils.R @@ -45,10 +45,10 @@ test_that("serializeToBytes on RDD", { writeLines(mockFile, fileName) text.rdd <- textFile(sc, fileName) - expect_true(getSerializedMode(text.rdd) == "string") + expect_equal(getSerializedMode(text.rdd), "string") ser.rdd <- serializeToBytes(text.rdd) expect_equal(collect(ser.rdd), as.list(mockFile)) - expect_true(getSerializedMode(ser.rdd) == "byte") + expect_equal(getSerializedMode(ser.rdd), "byte") unlink(fileName) }) From 4137f769b84300648ad933b0b3054d69a7316745 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Wed, 1 Jul 2015 10:30:54 -0700 Subject: [PATCH 120/122] [SPARK-8752][SQL] Add ExpectsInputTypes trait for defining expected input types. This patch doesn't actually introduce any code that uses the new ExpectsInputTypes. It just adds the trait so others can use it. Also renamed the old expectsInputTypes function to just inputTypes. We should add implicit type casting also in the future. Author: Reynold Xin Closes #7151 from rxin/expects-input-types and squashes the following commits: 16cf07b [Reynold Xin] [SPARK-8752][SQL] Add ExpectsInputTypes trait for defining expected input types. 
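As a minimal usage sketch (not part of the diff below), an expression could mix in the new trait as follows. The class name and its pass-through eval are illustrative assumptions only, and the snippet is written as if it sat alongside the other expressions in org.apache.spark.sql.catalyst.expressions so that file's existing imports apply:

    // Sketch only: a toy unary expression that declares, via the new ExpectsInputTypes
    // trait, that its single child is expected to produce a StringType value.
    case class HypotheticalPassThrough(child: Expression)
      extends UnaryExpression with ExpectsInputTypes {

      // The i-th entry is the type requirement for the i-th child.
      override def inputTypes: Seq[Any] = Seq(StringType)

      override def dataType: DataType = StringType
      override def nullable: Boolean = child.nullable

      // Toy semantics for illustration: forward the child's value unchanged.
      override def eval(input: InternalRow): Any = child.eval(input)
    }

Note that, as of this patch, only AutoCastInputTypes triggers cast insertion in the ImplicitTypeCasts rule; ExpectsInputTypes merely declares the expectation and defers checking to HiveTypeCoercion, which is what the remark above about adding implicit type casting in the future refers to.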
--- .../sql/catalyst/analysis/CheckAnalysis.scala | 1 - .../catalyst/analysis/HiveTypeCoercion.scala | 8 ++--- .../sql/catalyst/expressions/Expression.scala | 29 ++++++++++++++++--- .../spark/sql/catalyst/expressions/math.scala | 6 ++-- .../spark/sql/catalyst/expressions/misc.scala | 8 ++--- .../sql/catalyst/expressions/predicates.scala | 6 ++-- .../expressions/stringOperations.scala | 10 +++---- 7 files changed, 44 insertions(+), 24 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index a069b4710f38c..583338da57117 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -26,7 +26,6 @@ import org.apache.spark.sql.types._ * Throws user facing errors when passed invalid queries that fail to analyze. */ trait CheckAnalysis { - self: Analyzer => /** * Override to provide additional checks for correct analysis. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index a9d396d1faeeb..2ab5cb666fbcd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -45,7 +45,7 @@ object HiveTypeCoercion { IfCoercion :: Division :: PropagateTypes :: - AddCastForAutoCastInputTypes :: + ImplicitTypeCasts :: Nil // See https://cwiki.apache.org/confluence/display/Hive/LanguageManual+Types. @@ -705,13 +705,13 @@ object HiveTypeCoercion { * Casts types according to the expected input types for Expressions that have the trait * [[AutoCastInputTypes]]. */ - object AddCastForAutoCastInputTypes extends Rule[LogicalPlan] { + object ImplicitTypeCasts extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e - case e: AutoCastInputTypes if e.children.map(_.dataType) != e.expectedChildTypes => - val newC = (e.children, e.children.map(_.dataType), e.expectedChildTypes).zipped.map { + case e: AutoCastInputTypes if e.children.map(_.dataType) != e.inputTypes => + val newC = (e.children, e.children.map(_.dataType), e.inputTypes).zipped.map { case (child, actual, expected) => if (actual == expected) child else Cast(child, expected) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index b5063f32fa529..e18a3118945e8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -265,17 +265,38 @@ abstract class UnaryExpression extends Expression with trees.UnaryNode[Expressio } } +/** + * An trait that gets mixin to define the expected input types of an expression. + */ +trait ExpectsInputTypes { self: Expression => + + /** + * Expected input types from child expressions. The i-th position in the returned seq indicates + * the type requirement for the i-th child. + * + * The possible values at each position are: + * 1. a specific data type, e.g. 
LongType, StringType. + * 2. a non-leaf data type, e.g. NumericType, IntegralType, FractionalType. + * 3. a list of specific data types, e.g. Seq(StringType, BinaryType). + */ + def inputTypes: Seq[Any] + + override def checkInputDataTypes(): TypeCheckResult = { + // We will do the type checking in `HiveTypeCoercion`, so always returning success here. + TypeCheckResult.TypeCheckSuccess + } +} + /** * Expressions that require a specific `DataType` as input should implement this trait * so that the proper type conversions can be performed in the analyzer. */ -trait AutoCastInputTypes { - self: Expression => +trait AutoCastInputTypes { self: Expression => - def expectedChildTypes: Seq[DataType] + def inputTypes: Seq[DataType] override def checkInputDataTypes(): TypeCheckResult = { - // We will always do type casting for `ExpectsInputTypes` in `HiveTypeCoercion`, + // We will always do type casting for `AutoCastInputTypes` in `HiveTypeCoercion`, // so type mismatch error won't be reported here, but for underling `Cast`s. TypeCheckResult.TypeCheckSuccess } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index da63f2fa970cf..b51318dd5044c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -59,7 +59,7 @@ abstract class UnaryMathExpression(f: Double => Double, name: String) extends UnaryExpression with Serializable with AutoCastInputTypes { self: Product => - override def expectedChildTypes: Seq[DataType] = Seq(DoubleType) + override def inputTypes: Seq[DataType] = Seq(DoubleType) override def dataType: DataType = DoubleType override def nullable: Boolean = true override def toString: String = s"$name($child)" @@ -98,7 +98,7 @@ abstract class UnaryMathExpression(f: Double => Double, name: String) abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String) extends BinaryExpression with Serializable with AutoCastInputTypes { self: Product => - override def expectedChildTypes: Seq[DataType] = Seq(DoubleType, DoubleType) + override def inputTypes: Seq[DataType] = Seq(DoubleType, DoubleType) override def toString: String = s"$name($left, $right)" @@ -210,7 +210,7 @@ case class ToRadians(child: Expression) extends UnaryMathExpression(math.toRadia case class Bin(child: Expression) extends UnaryExpression with Serializable with AutoCastInputTypes { - override def expectedChildTypes: Seq[DataType] = Seq(LongType) + override def inputTypes: Seq[DataType] = Seq(LongType) override def dataType: DataType = StringType override def eval(input: InternalRow): Any = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index a7bcbe46c339a..407023e472081 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -36,7 +36,7 @@ case class Md5(child: Expression) override def dataType: DataType = StringType - override def expectedChildTypes: Seq[DataType] = Seq(BinaryType) + override def inputTypes: Seq[DataType] = Seq(BinaryType) override def eval(input: InternalRow): Any = { val value = child.eval(input) @@ -68,7 +68,7 @@ case class Sha2(left: Expression, right: Expression) override def 
toString: String = s"SHA2($left, $right)" - override def expectedChildTypes: Seq[DataType] = Seq(BinaryType, IntegerType) + override def inputTypes: Seq[DataType] = Seq(BinaryType, IntegerType) override def eval(input: InternalRow): Any = { val evalE1 = left.eval(input) @@ -151,7 +151,7 @@ case class Sha1(child: Expression) extends UnaryExpression with AutoCastInputTyp override def dataType: DataType = StringType - override def expectedChildTypes: Seq[DataType] = Seq(BinaryType) + override def inputTypes: Seq[DataType] = Seq(BinaryType) override def eval(input: InternalRow): Any = { val value = child.eval(input) @@ -179,7 +179,7 @@ case class Crc32(child: Expression) override def dataType: DataType = LongType - override def expectedChildTypes: Seq[DataType] = Seq(BinaryType) + override def inputTypes: Seq[DataType] = Seq(BinaryType) override def eval(input: InternalRow): Any = { val value = child.eval(input) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 98cd5aa8148c4..a777f77add2db 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -72,7 +72,7 @@ trait PredicateHelper { case class Not(child: Expression) extends UnaryExpression with Predicate with AutoCastInputTypes { override def toString: String = s"NOT $child" - override def expectedChildTypes: Seq[DataType] = Seq(BooleanType) + override def inputTypes: Seq[DataType] = Seq(BooleanType) override def eval(input: InternalRow): Any = { child.eval(input) match { @@ -122,7 +122,7 @@ case class InSet(value: Expression, hset: Set[Any]) case class And(left: Expression, right: Expression) extends BinaryExpression with Predicate with AutoCastInputTypes { - override def expectedChildTypes: Seq[DataType] = Seq(BooleanType, BooleanType) + override def inputTypes: Seq[DataType] = Seq(BooleanType, BooleanType) override def symbol: String = "&&" @@ -171,7 +171,7 @@ case class And(left: Expression, right: Expression) case class Or(left: Expression, right: Expression) extends BinaryExpression with Predicate with AutoCastInputTypes { - override def expectedChildTypes: Seq[DataType] = Seq(BooleanType, BooleanType) + override def inputTypes: Seq[DataType] = Seq(BooleanType, BooleanType) override def symbol: String = "||" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index ce184e4f32f18..4cbfc4e084948 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -32,7 +32,7 @@ trait StringRegexExpression extends AutoCastInputTypes { override def nullable: Boolean = left.nullable || right.nullable override def dataType: DataType = BooleanType - override def expectedChildTypes: Seq[DataType] = Seq(StringType, StringType) + override def inputTypes: Seq[DataType] = Seq(StringType, StringType) // try cache the pattern for Literal private lazy val cache: Pattern = right match { @@ -117,7 +117,7 @@ trait CaseConversionExpression extends AutoCastInputTypes { def convert(v: UTF8String): UTF8String override def dataType: DataType = StringType - override def expectedChildTypes: 
Seq[DataType] = Seq(StringType) + override def inputTypes: Seq[DataType] = Seq(StringType) override def eval(input: InternalRow): Any = { val evaluated = child.eval(input) @@ -165,7 +165,7 @@ trait StringComparison extends AutoCastInputTypes { override def nullable: Boolean = left.nullable || right.nullable - override def expectedChildTypes: Seq[DataType] = Seq(StringType, StringType) + override def inputTypes: Seq[DataType] = Seq(StringType, StringType) override def eval(input: InternalRow): Any = { val leftEval = left.eval(input) @@ -238,7 +238,7 @@ case class Substring(str: Expression, pos: Expression, len: Expression) if (str.dataType == BinaryType) str.dataType else StringType } - override def expectedChildTypes: Seq[DataType] = Seq(StringType, IntegerType, IntegerType) + override def inputTypes: Seq[DataType] = Seq(StringType, IntegerType, IntegerType) override def children: Seq[Expression] = str :: pos :: len :: Nil @@ -297,7 +297,7 @@ case class Substring(str: Expression, pos: Expression, len: Expression) */ case class StringLength(child: Expression) extends UnaryExpression with AutoCastInputTypes { override def dataType: DataType = IntegerType - override def expectedChildTypes: Seq[DataType] = Seq(StringType) + override def inputTypes: Seq[DataType] = Seq(StringType) override def eval(input: InternalRow): Any = { val string = child.eval(input) From 31b4a3d7f2be9053a041e5ae67418562a93d80d8 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 1 Jul 2015 10:31:35 -0700 Subject: [PATCH 121/122] [SPARK-8621] [SQL] support empty string as column name improve the empty check in `parseAttributeName` so that we can allow empty string as column name. Close https://github.com/apache/spark/pull/7117 Author: Wenchen Fan Closes #7149 from cloud-fan/8621 and squashes the following commits: efa9e3e [Wenchen Fan] support empty string --- .../spark/sql/catalyst/plans/logical/LogicalPlan.scala | 4 ++-- .../test/scala/org/apache/spark/sql/DataFrameSuite.scala | 7 +++++++ 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index b009a200b920f..e911b907e8536 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -161,7 +161,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { if (tmp.nonEmpty) throw e inBacktick = true } else if (char == '.') { - if (tmp.isEmpty) throw e + if (name(i - 1) == '.' 
|| i == name.length - 1) throw e nameParts += tmp.mkString tmp.clear() } else { @@ -170,7 +170,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { } i += 1 } - if (tmp.isEmpty || inBacktick) throw e + if (inBacktick) throw e nameParts += tmp.mkString nameParts.toSeq } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 50d324c0686fa..afb1cf5f8d1cb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -730,4 +730,11 @@ class DataFrameSuite extends QueryTest { val res11 = ctx.range(-1).select("id") assert(res11.count == 0) } + + test("SPARK-8621: support empty string column name") { + val df = Seq(Tuple1(1)).toDF("").as("t") + // We should allow empty string as column name + df.col("") + df.col("t.``") + } } From 184de91d15a4bfc5c014e8cf86211874bba4593f Mon Sep 17 00:00:00 2001 From: lewuathe Date: Wed, 1 Jul 2015 11:14:07 -0700 Subject: [PATCH 122/122] [SPARK-6263] [MLLIB] Python MLlib API missing items: Utils Implement missing API in pyspark. MLUtils * appendBias * loadVectors `kFold` is also missing however I am not sure `ClassTag` can be passed or restored through python. Author: lewuathe Closes #5707 from Lewuathe/SPARK-6263 and squashes the following commits: 16863ea [lewuathe] Merge master 3fc27e7 [lewuathe] Merge branch 'master' into SPARK-6263 6084e9c [lewuathe] Resolv conflict d2aa2a0 [lewuathe] Resolv conflict 9c329d8 [lewuathe] Fix efficiency 3a12a2d [lewuathe] Merge branch 'master' into SPARK-6263 1d4714b [lewuathe] Fix style b29e2bc [lewuathe] Remove scipy dependencies e32eb40 [lewuathe] Merge branch 'master' into SPARK-6263 25d3c9d [lewuathe] Remove unnecessary imports 7ec04db [lewuathe] Resolv conflict 1502d13 [lewuathe] Resolv conflict d6bd416 [lewuathe] Check existence of scipy.sparse 5d555b1 [lewuathe] Construct scipy.sparse matrix c345a44 [lewuathe] Merge branch 'master' into SPARK-6263 b8b5ef7 [lewuathe] Fix unnecessary sort method d254be7 [lewuathe] Merge branch 'master' into SPARK-6263 62a9c7e [lewuathe] Fix appendBias return type 454c73d [lewuathe] Merge branch 'master' into SPARK-6263 a353354 [lewuathe] Remove unnecessary appendBias implementation 44295c2 [lewuathe] Merge branch 'master' into SPARK-6263 64f72ad [lewuathe] Merge branch 'master' into SPARK-6263 c728046 [lewuathe] Fix style 2980569 [lewuathe] [SPARK-6263] Python MLlib API missing items: Utils --- .../mllib/api/python/PythonMLLibAPI.scala | 9 ++++ python/pyspark/mllib/tests.py | 43 +++++++++++++++++++ python/pyspark/mllib/util.py | 22 ++++++++++ 3 files changed, 74 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index a66a404d5c846..458fab48fef5a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -75,6 +75,15 @@ private[python] class PythonMLLibAPI extends Serializable { minPartitions: Int): JavaRDD[LabeledPoint] = MLUtils.loadLabeledPoints(jsc.sc, path, minPartitions) + /** + * Loads and serializes vectors saved with `RDD#saveAsTextFile`. 
+ * @param jsc Java SparkContext + * @param path file or directory path in any Hadoop-supported file system URI + * @return serialized vectors in a RDD + */ + def loadVectors(jsc: JavaSparkContext, path: String): RDD[Vector] = + MLUtils.loadVectors(jsc.sc, path) + private def trainRegressionModel( learner: GeneralizedLinearAlgorithm[_ <: GeneralizedLinearModel], data: JavaRDD[LabeledPoint], diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index f0091d6faccce..49ce125de7e78 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -54,6 +54,7 @@ from pyspark.mllib.feature import IDF from pyspark.mllib.feature import StandardScaler, ElementwiseProduct from pyspark.mllib.util import LinearDataGenerator +from pyspark.mllib.util import MLUtils from pyspark.serializers import PickleSerializer from pyspark.streaming import StreamingContext from pyspark.sql import SQLContext @@ -1290,6 +1291,48 @@ def func(rdd): self.assertTrue(mean_absolute_errors[1] - mean_absolute_errors[-1] > 2) +class MLUtilsTests(MLlibTestCase): + def test_append_bias(self): + data = [2.0, 2.0, 2.0] + ret = MLUtils.appendBias(data) + self.assertEqual(ret[3], 1.0) + self.assertEqual(type(ret), DenseVector) + + def test_append_bias_with_vector(self): + data = Vectors.dense([2.0, 2.0, 2.0]) + ret = MLUtils.appendBias(data) + self.assertEqual(ret[3], 1.0) + self.assertEqual(type(ret), DenseVector) + + def test_append_bias_with_sp_vector(self): + data = Vectors.sparse(3, {0: 2.0, 2: 2.0}) + expected = Vectors.sparse(4, {0: 2.0, 2: 2.0, 3: 1.0}) + # Returned value must be SparseVector + ret = MLUtils.appendBias(data) + self.assertEqual(ret, expected) + self.assertEqual(type(ret), SparseVector) + + def test_load_vectors(self): + import shutil + data = [ + [1.0, 2.0, 3.0], + [1.0, 2.0, 3.0] + ] + temp_dir = tempfile.mkdtemp() + load_vectors_path = os.path.join(temp_dir, "test_load_vectors") + try: + self.sc.parallelize(data).saveAsTextFile(load_vectors_path) + ret_rdd = MLUtils.loadVectors(self.sc, load_vectors_path) + ret = ret_rdd.collect() + self.assertEqual(len(ret), 2) + self.assertEqual(ret[0], DenseVector([1.0, 2.0, 3.0])) + self.assertEqual(ret[1], DenseVector([1.0, 2.0, 3.0])) + except: + self.fail() + finally: + shutil.rmtree(load_vectors_path) + + if __name__ == "__main__": if not _have_scipy: print("NOTE: Skipping SciPy tests as it does not seem to be installed") diff --git a/python/pyspark/mllib/util.py b/python/pyspark/mllib/util.py index 348238319e407..875d3b2d642c6 100644 --- a/python/pyspark/mllib/util.py +++ b/python/pyspark/mllib/util.py @@ -169,6 +169,28 @@ def loadLabeledPoints(sc, path, minPartitions=None): minPartitions = minPartitions or min(sc.defaultParallelism, 2) return callMLlibFunc("loadLabeledPoints", sc, path, minPartitions) + @staticmethod + def appendBias(data): + """ + Returns a new vector with `1.0` (bias) appended to + the end of the input vector. + """ + vec = _convert_to_vector(data) + if isinstance(vec, SparseVector): + newIndices = np.append(vec.indices, len(vec)) + newValues = np.append(vec.values, 1.0) + return SparseVector(len(vec) + 1, newIndices, newValues) + else: + return _convert_to_vector(np.append(vec.toArray(), 1.0)) + + @staticmethod + def loadVectors(sc, path): + """ + Loads vectors saved using `RDD[Vector].saveAsTextFile` + with the default number of partitions. + """ + return callMLlibFunc("loadVectors", sc, path) + class Saveable(object): """