From 66bb8003b949860b8652542e1232bc48665448c2 Mon Sep 17 00:00:00 2001 From: Carson Wang Date: Mon, 20 Jul 2015 18:08:59 -0700 Subject: [PATCH 01/32] [SPARK-9187] [WEBUI] Timeline view may show negative value for running tasks For running tasks, the executorRunTime metrics is 0 which causes negative executorComputingTime in the timeline. It also causes an incorrect SchedulerDelay time. ![timelinenegativevalue](https://cloud.githubusercontent.com/assets/9278199/8770953/f4362378-2eec-11e5-81e6-a06a07c04794.png) Author: Carson Wang Closes #7526 from carsonwang/timeline-negValue and squashes the following commits: 7b17db2 [Carson Wang] Fix negative value in timeline view --- .../org/apache/spark/ui/jobs/StagePage.scala | 23 ++++++++++++------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 27b82aaddd2e4..6e077bf3e70d5 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -537,20 +537,27 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { (metricsOpt.flatMap(_.shuffleWriteMetrics .map(_.shuffleWriteTime)).getOrElse(0L) / 1e6).toLong val shuffleWriteTimeProportion = toProportion(shuffleWriteTime) - val executorComputingTime = metricsOpt.map(_.executorRunTime).getOrElse(0L) - - shuffleReadTime - shuffleWriteTime - val executorComputingTimeProportion = toProportion(executorComputingTime) + val serializationTime = metricsOpt.map(_.resultSerializationTime).getOrElse(0L) val serializationTimeProportion = toProportion(serializationTime) val deserializationTime = metricsOpt.map(_.executorDeserializeTime).getOrElse(0L) val deserializationTimeProportion = toProportion(deserializationTime) val gettingResultTime = getGettingResultTime(taskUIData.taskInfo, currentTime) val gettingResultTimeProportion = toProportion(gettingResultTime) - val schedulerDelay = totalExecutionTime - - (executorComputingTime + shuffleReadTime + shuffleWriteTime + - serializationTime + deserializationTime + gettingResultTime) - val schedulerDelayProportion = - (100 - executorComputingTimeProportion - shuffleReadTimeProportion - + val schedulerDelay = + metricsOpt.map(getSchedulerDelay(taskInfo, _, currentTime)).getOrElse(0L) + val schedulerDelayProportion = toProportion(schedulerDelay) + + val executorOverhead = serializationTime + deserializationTime + val executorRunTime = if (taskInfo.running) { + totalExecutionTime - executorOverhead - gettingResultTime + } else { + metricsOpt.map(_.executorRunTime).getOrElse( + totalExecutionTime - executorOverhead - gettingResultTime) + } + val executorComputingTime = executorRunTime - shuffleReadTime - shuffleWriteTime + val executorComputingTimeProportion = + (100 - schedulerDelayProportion - shuffleReadTimeProportion - shuffleWriteTimeProportion - serializationTimeProportion - deserializationTimeProportion - gettingResultTimeProportion) From 047ccc8c9a88e74f7bc87709ee5d531f1d7a4228 Mon Sep 17 00:00:00 2001 From: Tarek Auel Date: Mon, 20 Jul 2015 18:16:49 -0700 Subject: [PATCH 02/32] [SPARK-9178][SQL] Add an empty string constant to UTF8String Jira: https://issues.apache.org/jira/browse/SPARK-9178 In order to avoid calls of `UTF8String.fromString("")` this pr adds an `EMPTY_STRING` constant to `UTF8String`. An `UTF8String` is immutable, so we can use a constant, isn't it? 
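For illustration only, a minimal Scala sketch of the call-site change this constant enables (the snippet is not part of the patch and assumes the constant's final name `EMPTY_UTF8` from the later commits in this PR):

    import org.apache.spark.unsafe.types.UTF8String

    // Previously: each call site allocates a fresh empty UTF8String.
    val allocated = UTF8String.fromString("")
    // With this patch: call sites can reuse the shared immutable constant.
    val shared = UTF8String.EMPTY_UTF8

    // The two are equal, so swapping them at call sites preserves behaviour.
    assert(allocated == shared)
    assert(shared.numBytes() == 0 && shared.numChars() == 0)
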
I searched for current usage of `UTF8String.fromString("")` with `grep -R "UTF8String.fromString(\"\")" .` Author: Tarek Auel Closes #7509 from tarekauel/SPARK-9178 and squashes the following commits: 8d6c405 [Tarek Auel] [SPARK-9178] revert intellij indents 3627b80 [Tarek Auel] [SPARK-9178] revert concat tests changes 3f5fbf5 [Tarek Auel] [SPARK-9178] rebase and add final to UTF8String.EMPTY_UTF8 47cda68 [Tarek Auel] Merge branch 'master' into SPARK-9178 4a37344 [Tarek Auel] [SPARK-9178] changed name to EMPTY_UTF8, added tests 748b87a [Tarek Auel] [SPARK-9178] Add empty string constant to UTF8String --- .../apache/spark/unsafe/types/UTF8String.java | 2 + .../spark/unsafe/types/UTF8StringSuite.java | 76 +++++++++---------- 2 files changed, 39 insertions(+), 39 deletions(-) diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index 819639f300177..fc63fe537d226 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -50,6 +50,8 @@ public final class UTF8String implements Comparable, Serializable { 5, 5, 5, 5, 6, 6, 6, 6}; + public static final UTF8String EMPTY_UTF8 = UTF8String.fromString(""); + /** * Creates an UTF8String from byte array, which should be encoded in UTF-8. * diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java index 6a21c27461163..d730b1d1384f5 100644 --- a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java +++ b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java @@ -54,6 +54,14 @@ public void basicTest() throws UnsupportedEncodingException { checkBasic("大 千 世 界", 7); } + @Test + public void emptyStringTest() { + assertEquals(fromString(""), EMPTY_UTF8); + assertEquals(fromBytes(new byte[0]), EMPTY_UTF8); + assertEquals(0, EMPTY_UTF8.numChars()); + assertEquals(0, EMPTY_UTF8.numBytes()); + } + @Test public void compareTo() { assertTrue(fromString("abc").compareTo(fromString("ABC")) > 0); @@ -88,9 +96,9 @@ public void upperAndLower() { @Test public void concatTest() { - assertEquals(fromString(""), concat()); + assertEquals(EMPTY_UTF8, concat()); assertEquals(null, concat((UTF8String) null)); - assertEquals(fromString(""), concat(fromString(""))); + assertEquals(EMPTY_UTF8, concat(EMPTY_UTF8)); assertEquals(fromString("ab"), concat(fromString("ab"))); assertEquals(fromString("ab"), concat(fromString("a"), fromString("b"))); assertEquals(fromString("abc"), concat(fromString("a"), fromString("b"), fromString("c"))); @@ -109,8 +117,8 @@ public void concatWsTest() { // If separator is null, concatWs should skip all null inputs and never return null. 
UTF8String sep = fromString("哈哈"); assertEquals( - fromString(""), - concatWs(sep, fromString(""))); + EMPTY_UTF8, + concatWs(sep, EMPTY_UTF8)); assertEquals( fromString("ab"), concatWs(sep, fromString("ab"))); @@ -127,7 +135,7 @@ public void concatWsTest() { fromString("a"), concatWs(sep, fromString("a"), null, null)); assertEquals( - fromString(""), + EMPTY_UTF8, concatWs(sep, null, null, null)); assertEquals( fromString("数据哈哈砖头"), @@ -136,7 +144,7 @@ public void concatWsTest() { @Test public void contains() { - assertTrue(fromString("").contains(fromString(""))); + assertTrue(EMPTY_UTF8.contains(EMPTY_UTF8)); assertTrue(fromString("hello").contains(fromString("ello"))); assertFalse(fromString("hello").contains(fromString("vello"))); assertFalse(fromString("hello").contains(fromString("hellooo"))); @@ -147,7 +155,7 @@ public void contains() { @Test public void startsWith() { - assertTrue(fromString("").startsWith(fromString(""))); + assertTrue(EMPTY_UTF8.startsWith(EMPTY_UTF8)); assertTrue(fromString("hello").startsWith(fromString("hell"))); assertFalse(fromString("hello").startsWith(fromString("ell"))); assertFalse(fromString("hello").startsWith(fromString("hellooo"))); @@ -158,7 +166,7 @@ public void startsWith() { @Test public void endsWith() { - assertTrue(fromString("").endsWith(fromString(""))); + assertTrue(EMPTY_UTF8.endsWith(EMPTY_UTF8)); assertTrue(fromString("hello").endsWith(fromString("ello"))); assertFalse(fromString("hello").endsWith(fromString("ellov"))); assertFalse(fromString("hello").endsWith(fromString("hhhello"))); @@ -169,7 +177,7 @@ public void endsWith() { @Test public void substring() { - assertEquals(fromString(""), fromString("hello").substring(0, 0)); + assertEquals(EMPTY_UTF8, fromString("hello").substring(0, 0)); assertEquals(fromString("el"), fromString("hello").substring(1, 3)); assertEquals(fromString("数"), fromString("数据砖头").substring(0, 1)); assertEquals(fromString("据砖"), fromString("数据砖头").substring(1, 3)); @@ -183,9 +191,9 @@ public void trims() { assertEquals(fromString("hello "), fromString(" hello ").trimLeft()); assertEquals(fromString(" hello"), fromString(" hello ").trimRight()); - assertEquals(fromString(""), fromString(" ").trim()); - assertEquals(fromString(""), fromString(" ").trimLeft()); - assertEquals(fromString(""), fromString(" ").trimRight()); + assertEquals(EMPTY_UTF8, fromString(" ").trim()); + assertEquals(EMPTY_UTF8, fromString(" ").trimLeft()); + assertEquals(EMPTY_UTF8, fromString(" ").trimRight()); assertEquals(fromString("数据砖头"), fromString(" 数据砖头 ").trim()); assertEquals(fromString("数据砖头 "), fromString(" 数据砖头 ").trimLeft()); @@ -198,9 +206,9 @@ public void trims() { @Test public void indexOf() { - assertEquals(0, fromString("").indexOf(fromString(""), 0)); - assertEquals(-1, fromString("").indexOf(fromString("l"), 0)); - assertEquals(0, fromString("hello").indexOf(fromString(""), 0)); + assertEquals(0, EMPTY_UTF8.indexOf(EMPTY_UTF8, 0)); + assertEquals(-1, EMPTY_UTF8.indexOf(fromString("l"), 0)); + assertEquals(0, fromString("hello").indexOf(EMPTY_UTF8, 0)); assertEquals(2, fromString("hello").indexOf(fromString("l"), 0)); assertEquals(3, fromString("hello").indexOf(fromString("l"), 3)); assertEquals(-1, fromString("hello").indexOf(fromString("a"), 0)); @@ -215,7 +223,7 @@ public void indexOf() { @Test public void reverse() { assertEquals(fromString("olleh"), fromString("hello").reverse()); - assertEquals(fromString(""), fromString("").reverse()); + assertEquals(EMPTY_UTF8, EMPTY_UTF8.reverse()); 
assertEquals(fromString("者行孙"), fromString("孙行者").reverse()); assertEquals(fromString("者行孙 olleh"), fromString("hello 孙行者").reverse()); } @@ -224,7 +232,7 @@ public void reverse() { public void repeat() { assertEquals(fromString("数d数d数d数d数d"), fromString("数d").repeat(5)); assertEquals(fromString("数d"), fromString("数d").repeat(1)); - assertEquals(fromString(""), fromString("数d").repeat(-1)); + assertEquals(EMPTY_UTF8, fromString("数d").repeat(-1)); } @Test @@ -234,14 +242,14 @@ public void pad() { assertEquals(fromString("?hello"), fromString("hello").lpad(6, fromString("????"))); assertEquals(fromString("???????hello"), fromString("hello").lpad(12, fromString("????"))); assertEquals(fromString("?????hello"), fromString("hello").lpad(10, fromString("?????"))); - assertEquals(fromString("???????"), fromString("").lpad(7, fromString("?????"))); + assertEquals(fromString("???????"), EMPTY_UTF8.lpad(7, fromString("?????"))); assertEquals(fromString("hel"), fromString("hello").rpad(3, fromString("????"))); assertEquals(fromString("hello"), fromString("hello").rpad(5, fromString("????"))); assertEquals(fromString("hello?"), fromString("hello").rpad(6, fromString("????"))); assertEquals(fromString("hello???????"), fromString("hello").rpad(12, fromString("????"))); assertEquals(fromString("hello?????"), fromString("hello").rpad(10, fromString("?????"))); - assertEquals(fromString("???????"), fromString("").rpad(7, fromString("?????"))); + assertEquals(fromString("???????"), EMPTY_UTF8.rpad(7, fromString("?????"))); assertEquals(fromString("数据砖"), fromString("数据砖头").lpad(3, fromString("????"))); @@ -265,26 +273,16 @@ public void pad() { @Test public void levenshteinDistance() { - assertEquals( - UTF8String.fromString("").levenshteinDistance(UTF8String.fromString("")), 0); - assertEquals( - UTF8String.fromString("").levenshteinDistance(UTF8String.fromString("a")), 1); - assertEquals( - UTF8String.fromString("aaapppp").levenshteinDistance(UTF8String.fromString("")), 7); - assertEquals( - UTF8String.fromString("frog").levenshteinDistance(UTF8String.fromString("fog")), 1); - assertEquals( - UTF8String.fromString("fly").levenshteinDistance(UTF8String.fromString("ant")),3); - assertEquals( - UTF8String.fromString("elephant").levenshteinDistance(UTF8String.fromString("hippo")), 7); - assertEquals( - UTF8String.fromString("hippo").levenshteinDistance(UTF8String.fromString("elephant")), 7); - assertEquals( - UTF8String.fromString("hippo").levenshteinDistance(UTF8String.fromString("zzzzzzzz")), 8); - assertEquals( - UTF8String.fromString("hello").levenshteinDistance(UTF8String.fromString("hallo")),1); - assertEquals( - UTF8String.fromString("世界千世").levenshteinDistance(UTF8String.fromString("千a世b")),4); + assertEquals(EMPTY_UTF8.levenshteinDistance(EMPTY_UTF8), 0); + assertEquals(EMPTY_UTF8.levenshteinDistance(fromString("a")), 1); + assertEquals(fromString("aaapppp").levenshteinDistance(EMPTY_UTF8), 7); + assertEquals(fromString("frog").levenshteinDistance(fromString("fog")), 1); + assertEquals(fromString("fly").levenshteinDistance(fromString("ant")),3); + assertEquals(fromString("elephant").levenshteinDistance(fromString("hippo")), 7); + assertEquals(fromString("hippo").levenshteinDistance(fromString("elephant")), 7); + assertEquals(fromString("hippo").levenshteinDistance(fromString("zzzzzzzz")), 8); + assertEquals(fromString("hello").levenshteinDistance(fromString("hallo")),1); + assertEquals(fromString("世界千世").levenshteinDistance(fromString("千a世b")),4); } @Test From 6853ac7c8c76003160fc861ddcc8e8e39e4a5924 
Mon Sep 17 00:00:00 2001 From: Tarek Auel Date: Mon, 20 Jul 2015 18:21:05 -0700 Subject: [PATCH 03/32] [SPARK-9156][SQL] codegen StringSplit Jira: https://issues.apache.org/jira/browse/SPARK-9156 Author: Tarek Auel Closes #7547 from tarekauel/SPARK-9156 and squashes the following commits: 0be2700 [Tarek Auel] [SPARK-9156][SQL] indention fix b860eaf [Tarek Auel] [SPARK-9156][SQL] codegen StringSplit 5ad6a1f [Tarek Auel] [SPARK-9156] codegen StringSplit --- .../sql/catalyst/expressions/stringOperations.scala | 12 ++++++++---- .../org/apache/spark/unsafe/types/UTF8String.java | 9 +++++++++ .../apache/spark/unsafe/types/UTF8StringSuite.java | 11 +++++++++++ 3 files changed, 28 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index a5682428b3d40..5c1908d55576a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -615,7 +615,7 @@ case class StringSpace(child: Expression) * Splits str around pat (pattern is a regular expression). */ case class StringSplit(str: Expression, pattern: Expression) - extends BinaryExpression with ImplicitCastInputTypes with CodegenFallback { + extends BinaryExpression with ImplicitCastInputTypes { override def left: Expression = str override def right: Expression = pattern @@ -623,9 +623,13 @@ case class StringSplit(str: Expression, pattern: Expression) override def inputTypes: Seq[DataType] = Seq(StringType, StringType) override def nullSafeEval(string: Any, regex: Any): Any = { - val splits = - string.asInstanceOf[UTF8String].toString.split(regex.asInstanceOf[UTF8String].toString, -1) - splits.toSeq.map(UTF8String.fromString) + string.asInstanceOf[UTF8String].split(regex.asInstanceOf[UTF8String], -1).toSeq + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + nullSafeCodeGen(ctx, ev, (str, pattern) => + s"""${ev.primitive} = scala.collection.JavaConversions.asScalaBuffer( + java.util.Arrays.asList($str.split($pattern, -1)));""") } override def prettyName: String = "split" diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index fc63fe537d226..ed354f7f877f1 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -487,6 +487,15 @@ public static UTF8String concatWs(UTF8String separator, UTF8String... 
inputs) { return fromBytes(result); } + public UTF8String[] split(UTF8String pattern, int limit) { + String[] splits = toString().split(pattern.toString(), limit); + UTF8String[] res = new UTF8String[splits.length]; + for (int i = 0; i < res.length; i++) { + res[i] = fromString(splits[i]); + } + return res; + } + @Override public String toString() { try { diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java index d730b1d1384f5..1f5572c509bdb 100644 --- a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java +++ b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java @@ -18,6 +18,7 @@ package org.apache.spark.unsafe.types; import java.io.UnsupportedEncodingException; +import java.util.Arrays; import org.junit.Test; @@ -270,6 +271,16 @@ public void pad() { fromString("数据砖头孙行者孙行者孙行"), fromString("数据砖头").rpad(12, fromString("孙行者"))); } + + @Test + public void split() { + assertTrue(Arrays.equals(fromString("ab,def,ghi").split(fromString(","), -1), + new UTF8String[]{fromString("ab"), fromString("def"), fromString("ghi")})); + assertTrue(Arrays.equals(fromString("ab,def,ghi").split(fromString(","), 2), + new UTF8String[]{fromString("ab"), fromString("def,ghi")})); + assertTrue(Arrays.equals(fromString("ab,def,ghi").split(fromString(","), 2), + new UTF8String[]{fromString("ab"), fromString("def,ghi")})); + } @Test public void levenshteinDistance() { From e90543e5366808332bbde18d78cccd4d064a3338 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 20 Jul 2015 18:23:51 -0700 Subject: [PATCH 04/32] [SPARK-9142][SQL] Removing unnecessary self types in expressions. Also added documentation to expressions to explain the important traits and abstract classes. Author: Reynold Xin Closes #7550 from rxin/remove-self-types and squashes the following commits: b2a3ec1 [Reynold Xin] [SPARK-9142][SQL] Removing unnecessary self types in expressions. --- .../expressions/ExpectsInputTypes.scala | 4 +-- .../sql/catalyst/expressions/Expression.scala | 33 +++++++++++-------- .../expressions/codegen/CodegenFallback.scala | 2 +- 3 files changed, 22 insertions(+), 17 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala index ded89e85dea79..abe6457747550 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion.ImplicitTypeCasts * * Most function expressions (e.g. [[Substring]] should extends [[ImplicitCastInputTypes]]) instead. */ -trait ExpectsInputTypes { self: Expression => +trait ExpectsInputTypes extends Expression { /** * Expected input types from child expressions. The i-th position in the returned seq indicates @@ -60,6 +60,6 @@ trait ExpectsInputTypes { self: Expression => /** * A mixin for the analyzer to perform implicit type casting using [[ImplicitTypeCasts]]. 
*/ -trait ImplicitCastInputTypes extends ExpectsInputTypes { self: Expression => +trait ImplicitCastInputTypes extends ExpectsInputTypes { // No other methods } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index da599b8963340..aada25276adb7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -19,19 +19,12 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedAttribute} -import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.types._ //////////////////////////////////////////////////////////////////////////////////////////////////// -// This file defines the basic expression abstract classes in Catalyst, including: -// Expression: the base expression abstract class -// LeafExpression -// UnaryExpression -// BinaryExpression -// BinaryOperator -// -// For details, see their classdocs. +// This file defines the basic expression abstract classes in Catalyst. //////////////////////////////////////////////////////////////////////////////////////////////////// /** @@ -39,9 +32,21 @@ import org.apache.spark.sql.types._ * * If an expression wants to be exposed in the function registry (so users can call it with * "name(arguments...)", the concrete implementation must be a case class whose constructor - * arguments are all Expressions types. + * arguments are all Expressions types. See [[Substring]] for an example. + * + * There are a few important traits: + * + * - [[Nondeterministic]]: an expression that is not deterministic. + * - [[Unevaluable]]: an expression that is not supposed to be evaluated. + * - [[CodegenFallback]]: an expression that does not have code gen implemented and falls back to + * interpreted mode. + * + * - [[LeafExpression]]: an expression that has no child. + * - [[UnaryExpression]]: an expression that has one child. + * - [[BinaryExpression]]: an expression that has two children. + * - [[BinaryOperator]]: a special case of [[BinaryExpression]] that requires two children to have + * the same output data type. * - * See [[Substring]] for an example. */ abstract class Expression extends TreeNode[Expression] { @@ -176,7 +181,7 @@ abstract class Expression extends TreeNode[Expression] { * An expression that cannot be evaluated. Some expressions don't live past analysis or optimization * time (e.g. Star). This trait is used by those expressions. */ -trait Unevaluable { self: Expression => +trait Unevaluable extends Expression { override def eval(input: InternalRow = null): Any = throw new UnsupportedOperationException(s"Cannot evaluate expression: $this") @@ -185,11 +190,11 @@ trait Unevaluable { self: Expression => throw new UnsupportedOperationException(s"Cannot evaluate expression: $this") } + /** * An expression that is nondeterministic. 
*/ -trait Nondeterministic { self: Expression => - +trait Nondeterministic extends Expression { override def deterministic: Boolean = false } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala index bf4f600cb26e5..6b187f05604fd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.expressions.Expression /** * A trait that can be used to provide a fallback mode for expression code generation. */ -trait CodegenFallback { self: Expression => +trait CodegenFallback extends Expression { protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { ctx.references += this From 936a96cb31a6dd7d8685bce05103e779ca02e763 Mon Sep 17 00:00:00 2001 From: Tarek Auel Date: Mon, 20 Jul 2015 19:17:59 -0700 Subject: [PATCH 05/32] [SPARK-9164] [SQL] codegen hex/unhex Jira: https://issues.apache.org/jira/browse/SPARK-9164 The diff looks heavy, but I just moved the `hex` and `unhex` methods to `object Hex`. This allows me to call them from `eval` and `codeGen` Author: Tarek Auel Closes #7548 from tarekauel/SPARK-9164 and squashes the following commits: dd91c57 [Tarek Auel] [SPARK-9164][SQL] codegen hex/unhex --- .../spark/sql/catalyst/expressions/math.scala | 96 +++++++++++-------- 1 file changed, 57 insertions(+), 39 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index 7ce64d29ba59a..7a9be02ba45b3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -489,28 +489,8 @@ object Hex { (0 to 5).foreach(i => array('a' + i) = (i + 10).toByte) array } -} -/** - * If the argument is an INT or binary, hex returns the number as a STRING in hexadecimal format. - * Otherwise if the number is a STRING, it converts each character into its hex representation - * and returns the resulting STRING. Negative numbers would be treated as two's complement. 
- */ -case class Hex(child: Expression) - extends UnaryExpression with ImplicitCastInputTypes with CodegenFallback { - - override def inputTypes: Seq[AbstractDataType] = - Seq(TypeCollection(LongType, BinaryType, StringType)) - - override def dataType: DataType = StringType - - protected override def nullSafeEval(num: Any): Any = child.dataType match { - case LongType => hex(num.asInstanceOf[Long]) - case BinaryType => hex(num.asInstanceOf[Array[Byte]]) - case StringType => hex(num.asInstanceOf[UTF8String].getBytes) - } - - private[this] def hex(bytes: Array[Byte]): UTF8String = { + def hex(bytes: Array[Byte]): UTF8String = { val length = bytes.length val value = new Array[Byte](length * 2) var i = 0 @@ -522,7 +502,7 @@ case class Hex(child: Expression) UTF8String.fromBytes(value) } - private def hex(num: Long): UTF8String = { + def hex(num: Long): UTF8String = { // Extract the hex digits of num into value[] from right to left val value = new Array[Byte](16) var numBuf = num @@ -534,24 +514,8 @@ case class Hex(child: Expression) } while (numBuf != 0) UTF8String.fromBytes(java.util.Arrays.copyOfRange(value, value.length - len, value.length)) } -} -/** - * Performs the inverse operation of HEX. - * Resulting characters are returned as a byte array. - */ -case class Unhex(child: Expression) - extends UnaryExpression with ImplicitCastInputTypes with CodegenFallback { - - override def inputTypes: Seq[AbstractDataType] = Seq(StringType) - - override def nullable: Boolean = true - override def dataType: DataType = BinaryType - - protected override def nullSafeEval(num: Any): Any = - unhex(num.asInstanceOf[UTF8String].getBytes) - - private[this] def unhex(bytes: Array[Byte]): Array[Byte] = { + def unhex(bytes: Array[Byte]): Array[Byte] = { val out = new Array[Byte]((bytes.length + 1) >> 1) var i = 0 if ((bytes.length & 0x01) != 0) { @@ -583,6 +547,60 @@ case class Unhex(child: Expression) } } +/** + * If the argument is an INT or binary, hex returns the number as a STRING in hexadecimal format. + * Otherwise if the number is a STRING, it converts each character into its hex representation + * and returns the resulting STRING. Negative numbers would be treated as two's complement. + */ +case class Hex(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { + + override def inputTypes: Seq[AbstractDataType] = + Seq(TypeCollection(LongType, BinaryType, StringType)) + + override def dataType: DataType = StringType + + protected override def nullSafeEval(num: Any): Any = child.dataType match { + case LongType => Hex.hex(num.asInstanceOf[Long]) + case BinaryType => Hex.hex(num.asInstanceOf[Array[Byte]]) + case StringType => Hex.hex(num.asInstanceOf[UTF8String].getBytes) + } + + override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + nullSafeCodeGen(ctx, ev, (c) => { + val hex = Hex.getClass.getName.stripSuffix("$") + s"${ev.primitive} = " + (child.dataType match { + case StringType => s"""$hex.hex($c.getBytes());""" + case _ => s"""$hex.hex($c);""" + }) + }) + } +} + +/** + * Performs the inverse operation of HEX. + * Resulting characters are returned as a byte array. 
+ */ +case class Unhex(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { + + override def inputTypes: Seq[AbstractDataType] = Seq(StringType) + + override def nullable: Boolean = true + override def dataType: DataType = BinaryType + + protected override def nullSafeEval(num: Any): Any = + Hex.unhex(num.asInstanceOf[UTF8String].getBytes) + + override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + nullSafeCodeGen(ctx, ev, (c) => { + val hex = Hex.getClass.getName.stripSuffix("$") + s""" + ${ev.primitive} = $hex.unhex($c.getBytes()); + ${ev.isNull} = ${ev.primitive} == null; + """ + }) + } +} + //////////////////////////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////////////////////////// From 2bdf9914ab709bf9c1cdd17fc5dd7a69f6d46f29 Mon Sep 17 00:00:00 2001 From: Yu ISHIKAWA Date: Tue, 21 Jul 2015 11:38:22 +0900 Subject: [PATCH 06/32] [SPARK-9052] [SPARKR] Fix comments after curly braces [[SPARK-9052] Fix comments after curly braces - ASF JIRA](https://issues.apache.org/jira/browse/SPARK-9052) This is the full result of lintr at the rivision:011551620faa87107a787530f074af3d9be7e695. [[SPARK-9052] the result of lint-r at the revision:011551620faa87107a787530f074af3d9be7e695](https://gist.github.com/yu-iskw/e7246041b173a3f29482) This is the difference of the result between before and after. https://gist.github.com/yu-iskw/e7246041b173a3f29482/revisions Author: Yu ISHIKAWA Closes #7440 from yu-iskw/SPARK-9052 and squashes the following commits: 015d738 [Yu ISHIKAWA] Fix the indentations and move the placement of commna 5cc30fe [Yu ISHIKAWA] Fix the indentation in a condition 4ead0e5 [Yu ISHIKAWA] [SPARK-9052][SparkR] Fix comments after curly braces --- R/pkg/R/schema.R | 13 ++++++++----- R/pkg/R/utils.R | 33 ++++++++++++++++++++++----------- 2 files changed, 30 insertions(+), 16 deletions(-) diff --git a/R/pkg/R/schema.R b/R/pkg/R/schema.R index 06df430687682..79c744ef29c23 100644 --- a/R/pkg/R/schema.R +++ b/R/pkg/R/schema.R @@ -69,11 +69,14 @@ structType.structField <- function(x, ...) { #' @param ... further arguments passed to or from other methods print.structType <- function(x, ...) { cat("StructType\n", - sapply(x$fields(), function(field) { paste("|-", "name = \"", field$name(), - "\", type = \"", field$dataType.toString(), - "\", nullable = ", field$nullable(), "\n", - sep = "") }) - , sep = "") + sapply(x$fields(), + function(field) { + paste("|-", "name = \"", field$name(), + "\", type = \"", field$dataType.toString(), + "\", nullable = ", field$nullable(), "\n", + sep = "") + }), + sep = "") } #' structField diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R index 950ba74dbe017..3f45589a50443 100644 --- a/R/pkg/R/utils.R +++ b/R/pkg/R/utils.R @@ -390,14 +390,17 @@ processClosure <- function(node, oldEnv, defVars, checkedFuncs, newEnv) { for (i in 1:nodeLen) { processClosure(node[[i]], oldEnv, defVars, checkedFuncs, newEnv) } - } else { # if node[[1]] is length of 1, check for some R special functions. + } else { + # if node[[1]] is length of 1, check for some R special functions. nodeChar <- as.character(node[[1]]) - if (nodeChar == "{" || nodeChar == "(") { # Skip start symbol. + if (nodeChar == "{" || nodeChar == "(") { + # Skip start symbol. for (i in 2:nodeLen) { processClosure(node[[i]], oldEnv, defVars, checkedFuncs, newEnv) } } else if (nodeChar == "<-" || nodeChar == "=" || - nodeChar == "<<-") { # Assignment Ops. 
+ nodeChar == "<<-") { + # Assignment Ops. defVar <- node[[2]] if (length(defVar) == 1 && typeof(defVar) == "symbol") { # Add the defined variable name into defVars. @@ -408,14 +411,16 @@ processClosure <- function(node, oldEnv, defVars, checkedFuncs, newEnv) { for (i in 3:nodeLen) { processClosure(node[[i]], oldEnv, defVars, checkedFuncs, newEnv) } - } else if (nodeChar == "function") { # Function definition. + } else if (nodeChar == "function") { + # Function definition. # Add parameter names. newArgs <- names(node[[2]]) lapply(newArgs, function(arg) { addItemToAccumulator(defVars, arg) }) for (i in 3:nodeLen) { processClosure(node[[i]], oldEnv, defVars, checkedFuncs, newEnv) } - } else if (nodeChar == "$") { # Skip the field. + } else if (nodeChar == "$") { + # Skip the field. processClosure(node[[2]], oldEnv, defVars, checkedFuncs, newEnv) } else if (nodeChar == "::" || nodeChar == ":::") { processClosure(node[[3]], oldEnv, defVars, checkedFuncs, newEnv) @@ -429,7 +434,8 @@ processClosure <- function(node, oldEnv, defVars, checkedFuncs, newEnv) { (typeof(node) == "symbol" || typeof(node) == "language")) { # Base case: current AST node is a leaf node and a symbol or a function call. nodeChar <- as.character(node) - if (!nodeChar %in% defVars$data) { # Not a function parameter or local variable. + if (!nodeChar %in% defVars$data) { + # Not a function parameter or local variable. func.env <- oldEnv topEnv <- parent.env(.GlobalEnv) # Search in function environment, and function's enclosing environments @@ -439,20 +445,24 @@ processClosure <- function(node, oldEnv, defVars, checkedFuncs, newEnv) { while (!identical(func.env, topEnv)) { # Namespaces other than "SparkR" will not be searched. if (!isNamespace(func.env) || - (getNamespaceName(func.env) == "SparkR" && - !(nodeChar %in% getNamespaceExports("SparkR")))) { # Only include SparkR internals. + (getNamespaceName(func.env) == "SparkR" && + !(nodeChar %in% getNamespaceExports("SparkR")))) { + # Only include SparkR internals. + # Set parameter 'inherits' to FALSE since we do not need to search in # attached package environments. if (tryCatch(exists(nodeChar, envir = func.env, inherits = FALSE), error = function(e) { FALSE })) { obj <- get(nodeChar, envir = func.env, inherits = FALSE) - if (is.function(obj)) { # If the node is a function call. + if (is.function(obj)) { + # If the node is a function call. funcList <- mget(nodeChar, envir = checkedFuncs, inherits = F, ifnotfound = list(list(NULL)))[[1]] found <- sapply(funcList, function(func) { ifelse(identical(func, obj), TRUE, FALSE) }) - if (sum(found) > 0) { # If function has been examined, ignore. + if (sum(found) > 0) { + # If function has been examined, ignore. break } # Function has not been examined, record it and recursively clean its closure. @@ -495,7 +505,8 @@ cleanClosure <- function(func, checkedFuncs = new.env()) { # environment. First, function's arguments are added to defVars. defVars <- initAccumulator() argNames <- names(as.list(args(func))) - for (i in 1:(length(argNames) - 1)) { # Remove the ending NULL in pairlist. + for (i in 1:(length(argNames) - 1)) { + # Remove the ending NULL in pairlist. addItemToAccumulator(defVars, argNames[i]) } # Recursively examine variables in the function body. From 1cbdd8991898912a8471a7070c472a0edb92487c Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Mon, 20 Jul 2015 20:49:38 -0700 Subject: [PATCH 07/32] [SPARK-9201] [ML] Initial integration of MLlib + SparkR using RFormula This exposes the SparkR:::glm() and SparkR:::predict() APIs. 
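For context, a rough Scala sketch of the pipeline the new wrapper builds for the default "gaussian" family; it mirrors the `SparkRWrappers.fitRModelFormula` method added further down in this patch and is itself not part of the change:

    import org.apache.spark.ml.Pipeline
    import org.apache.spark.ml.feature.RFormula
    import org.apache.spark.ml.regression.LinearRegression
    import org.apache.spark.sql.DataFrame

    // Translate the R model formula into an RFormula feature stage, then fit a
    // regularized linear regression on the columns that stage produces.
    def fitGaussian(formula: String, df: DataFrame, lambda: Double, alpha: Double) = {
      val features = new RFormula().setFormula(formula)
      val lr = new LinearRegression().setRegParam(lambda).setElasticNetParam(alpha)
      new Pipeline().setStages(Array(features, lr)).fit(df)
    }

Keeping the formula handling on the Scala side is what lets the R `glm()` entry point stay a thin wrapper over a single JVM call.
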
It was necessary to change RFormula to silently drop the label column if it was missing from the input dataset, which is kind of a hack but necessary to integrate with the Pipeline API. The umbrella design doc for MLlib + SparkR integration can be viewed here: https://docs.google.com/document/d/10NZNSEurN2EdWM31uFYsgayIPfCFHiuIu3pCWrUmP_c/edit mengxr Author: Eric Liang Closes #7483 from ericl/spark-8774 and squashes the following commits: 3dfac0c [Eric Liang] update 17ef516 [Eric Liang] more comments 1753a0f [Eric Liang] make glm generic b0f50f8 [Eric Liang] equivalence test 550d56d [Eric Liang] export methods c015697 [Eric Liang] second pass 117949a [Eric Liang] comments 5afbc67 [Eric Liang] test label columns 6b7f15f [Eric Liang] Fri Jul 17 14:20:22 PDT 2015 3a63ae5 [Eric Liang] Fri Jul 17 13:41:52 PDT 2015 ce61367 [Eric Liang] Fri Jul 17 13:41:17 PDT 2015 0299c59 [Eric Liang] Fri Jul 17 13:40:32 PDT 2015 e37603f [Eric Liang] Fri Jul 17 12:15:03 PDT 2015 d417d0c [Eric Liang] Merge remote-tracking branch 'upstream/master' into spark-8774 29a2ce7 [Eric Liang] Merge branch 'spark-8774-1' into spark-8774 d1959d2 [Eric Liang] clarify comment 2db68aa [Eric Liang] second round of comments dc3c943 [Eric Liang] address comments 5765ec6 [Eric Liang] fix style checks 1f361b0 [Eric Liang] doc d33211b [Eric Liang] r support fb0826b [Eric Liang] [SPARK-8774] Add R model formula with basic support as a transformer --- R/pkg/DESCRIPTION | 1 + R/pkg/NAMESPACE | 4 + R/pkg/R/generics.R | 4 + R/pkg/R/mllib.R | 73 +++++++++++++++++++ R/pkg/inst/tests/test_mllib.R | 42 +++++++++++ .../apache/spark/ml/feature/RFormula.scala | 14 +++- .../apache/spark/ml/r/SparkRWrappers.scala | 41 +++++++++++ .../spark/ml/feature/RFormulaSuite.scala | 9 +++ 8 files changed, 185 insertions(+), 3 deletions(-) create mode 100644 R/pkg/R/mllib.R create mode 100644 R/pkg/inst/tests/test_mllib.R create mode 100644 mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION index d028821534b1a..4949d86d20c91 100644 --- a/R/pkg/DESCRIPTION +++ b/R/pkg/DESCRIPTION @@ -29,6 +29,7 @@ Collate: 'client.R' 'context.R' 'deserialize.R' + 'mllib.R' 'serialize.R' 'sparkR.R' 'utils.R' diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 331307c2077a5..5834813319bfd 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -10,6 +10,10 @@ export("sparkR.init") export("sparkR.stop") export("print.jobj") +# MLlib integration +exportMethods("glm", + "predict") + # Job group lifecycle management methods export("setJobGroup", "clearJobGroup", diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index ebe6fbd97ce86..39b5586f7c90e 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -661,3 +661,7 @@ setGeneric("toRadians", function(x) { standardGeneric("toRadians") }) #' @rdname column #' @export setGeneric("upper", function(x) { standardGeneric("upper") }) + +#' @rdname glm +#' @export +setGeneric("glm") diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R new file mode 100644 index 0000000000000..258e354081fc1 --- /dev/null +++ b/R/pkg/R/mllib.R @@ -0,0 +1,73 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# mllib.R: Provides methods for MLlib integration + +#' @title S4 class that represents a PipelineModel +#' @param model A Java object reference to the backing Scala PipelineModel +#' @export +setClass("PipelineModel", representation(model = "jobj")) + +#' Fits a generalized linear model +#' +#' Fits a generalized linear model, similarly to R's glm(). Also see the glmnet package. +#' +#' @param formula A symbolic description of the model to be fitted. Currently only a few formula +#' operators are supported, including '~' and '+'. +#' @param data DataFrame for training +#' @param family Error distribution. "gaussian" -> linear regression, "binomial" -> logistic reg. +#' @param lambda Regularization parameter +#' @param alpha Elastic-net mixing parameter (see glmnet's documentation for details) +#' @return a fitted MLlib model +#' @rdname glm +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlContext <- sparkRSQL.init(sc) +#' data(iris) +#' df <- createDataFrame(sqlContext, iris) +#' model <- glm(Sepal_Length ~ Sepal_Width, df) +#'} +setMethod("glm", signature(formula = "formula", family = "ANY", data = "DataFrame"), + function(formula, family = c("gaussian", "binomial"), data, lambda = 0, alpha = 0) { + family <- match.arg(family) + model <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", + "fitRModelFormula", deparse(formula), data@sdf, family, lambda, + alpha) + return(new("PipelineModel", model = model)) + }) + +#' Make predictions from a model +#' +#' Makes predictions from a model produced by glm(), similarly to R's predict(). +#' +#' @param model A fitted MLlib model +#' @param newData DataFrame for testing +#' @return DataFrame containing predicted values +#' @rdname glm +#' @export +#' @examples +#'\dontrun{ +#' model <- glm(y ~ x, trainingData) +#' predicted <- predict(model, testData) +#' showDF(predicted) +#'} +setMethod("predict", signature(object = "PipelineModel"), + function(object, newData) { + return(dataFrame(callJMethod(object@model, "transform", newData@sdf))) + }) diff --git a/R/pkg/inst/tests/test_mllib.R b/R/pkg/inst/tests/test_mllib.R new file mode 100644 index 0000000000000..a492763344ae6 --- /dev/null +++ b/R/pkg/inst/tests/test_mllib.R @@ -0,0 +1,42 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +library(testthat) + +context("MLlib functions") + +# Tests for MLlib functions in SparkR + +sc <- sparkR.init() + +sqlContext <- sparkRSQL.init(sc) + +test_that("glm and predict", { + training <- createDataFrame(sqlContext, iris) + test <- select(training, "Sepal_Length") + model <- glm(Sepal_Width ~ Sepal_Length, training, family = "gaussian") + prediction <- predict(model, test) + expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "double") +}) + +test_that("predictions match with native glm", { + training <- createDataFrame(sqlContext, iris) + model <- glm(Sepal_Width ~ Sepal_Length, data = training) + vals <- collect(select(predict(model, training), "prediction")) + rVals <- predict(glm(Sepal.Width ~ Sepal.Length, data = iris), iris) + expect_true(all(abs(rVals - vals) < 1e-9), rVals - vals) +}) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index 56169f2a01fc9..f7b46efa10e90 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -73,12 +73,16 @@ class RFormula(override val uid: String) val withFeatures = transformFeatures.transformSchema(schema) if (hasLabelCol(schema)) { withFeatures - } else { + } else if (schema.exists(_.name == parsedFormula.get.label)) { val nullable = schema(parsedFormula.get.label).dataType match { case _: NumericType | BooleanType => false case _ => true } StructType(withFeatures.fields :+ StructField($(labelCol), DoubleType, nullable)) + } else { + // Ignore the label field. This is a hack so that this transformer can also work on test + // datasets in a Pipeline. + withFeatures } } @@ -92,10 +96,10 @@ class RFormula(override val uid: String) override def toString: String = s"RFormula(${get(formula)})" private def transformLabel(dataset: DataFrame): DataFrame = { + val labelName = parsedFormula.get.label if (hasLabelCol(dataset.schema)) { dataset - } else { - val labelName = parsedFormula.get.label + } else if (dataset.schema.exists(_.name == labelName)) { dataset.schema(labelName).dataType match { case _: NumericType | BooleanType => dataset.withColumn($(labelCol), dataset(labelName).cast(DoubleType)) @@ -103,6 +107,10 @@ class RFormula(override val uid: String) case other => throw new IllegalArgumentException("Unsupported type for label: " + other) } + } else { + // Ignore the label field. This is a hack so that this transformer can also work on test + // datasets in a Pipeline. + dataset } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala new file mode 100644 index 0000000000000..1ee080641e3e3 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.api.r + +import org.apache.spark.ml.feature.RFormula +import org.apache.spark.ml.classification.LogisticRegression +import org.apache.spark.ml.regression.LinearRegression +import org.apache.spark.ml.{Pipeline, PipelineModel} +import org.apache.spark.sql.DataFrame + +private[r] object SparkRWrappers { + def fitRModelFormula( + value: String, + df: DataFrame, + family: String, + lambda: Double, + alpha: Double): PipelineModel = { + val formula = new RFormula().setFormula(value) + val estimator = family match { + case "gaussian" => new LinearRegression().setRegParam(lambda).setElasticNetParam(alpha) + case "binomial" => new LogisticRegression().setRegParam(lambda).setElasticNetParam(alpha) + } + val pipeline = new Pipeline().setStages(Array(formula, estimator)) + pipeline.fit(df) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala index fa8611b243a9f..79c4ccf02d4e0 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala @@ -74,6 +74,15 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext { } } + test("allow missing label column for test datasets") { + val formula = new RFormula().setFormula("y ~ x").setLabelCol("label") + val original = sqlContext.createDataFrame(Seq((0, 1.0), (2, 2.0))).toDF("x", "_not_y") + val resultSchema = formula.transformSchema(original.schema) + assert(resultSchema.length == 3) + assert(!resultSchema.exists(_.name == "label")) + assert(resultSchema.toString == formula.transform(original).schema.toString) + } + // TODO(ekl) enable after we implement string label support // test("transform string label") { // val formula = new RFormula().setFormula("name ~ id") From a3c7a3ce32697ad293b8bcaf29f9384c8255b37f Mon Sep 17 00:00:00 2001 From: Tarek Auel Date: Mon, 20 Jul 2015 22:08:12 -0700 Subject: [PATCH 08/32] [SPARK-9132][SPARK-9163][SQL] codegen conv Jira: https://issues.apache.org/jira/browse/SPARK-9132 https://issues.apache.org/jira/browse/SPARK-9163 rxin as you proposed in the Jira ticket, I just moved the logic to a separate object. I haven't changed anything of the logic of `NumberConverter`. 
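For illustration, a short Scala sketch of the behaviour the extracted helper exposes to both `eval` and codegen; the snippet is not part of the patch, and the expected values follow the `Conv` checks in `MathFunctionsSuite` below:

    import org.apache.spark.sql.catalyst.util.NumberConverter
    import org.apache.spark.unsafe.types.UTF8String

    // conv("-15", 10, -16): a negative target base keeps the result signed.
    val signedHex = NumberConverter.convert(UTF8String.fromString("-15").getBytes, 10, -16)
    assert(signedHex == UTF8String.fromString("-F"))

    // conv("big", 36, 16): base-36 digits above 9 are letters.
    val fromBase36 = NumberConverter.convert(UTF8String.fromString("big").getBytes, 36, 16)
    assert(fromBase36 == UTF8String.fromString("3A48"))
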
Author: Tarek Auel Closes #7552 from tarekauel/SPARK-9163 and squashes the following commits: 40dcde9 [Tarek Auel] [SPARK-9132][SPARK-9163][SQL] style fix fa985bd [Tarek Auel] [SPARK-9132][SPARK-9163][SQL] codegen conv --- .../spark/sql/catalyst/expressions/math.scala | 204 ++++-------------- .../sql/catalyst/util/NumberConverter.scala | 176 +++++++++++++++ .../expressions/MathFunctionsSuite.scala | 4 +- .../catalyst/util/NumberConverterSuite.scala | 40 ++++ 4 files changed, 263 insertions(+), 161 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/NumberConverter.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/NumberConverterSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index 7a9be02ba45b3..68cca0ad3d067 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckSuccess, TypeCheckFailure} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.util.NumberConverter import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -164,7 +165,7 @@ case class Cosh(child: Expression) extends UnaryMathExpression(math.cosh, "COSH" * @param toBaseExpr to which base */ case class Conv(numExpr: Expression, fromBaseExpr: Expression, toBaseExpr: Expression) - extends Expression with ImplicitCastInputTypes with CodegenFallback { + extends Expression with ImplicitCastInputTypes { override def foldable: Boolean = numExpr.foldable && fromBaseExpr.foldable && toBaseExpr.foldable @@ -179,169 +180,54 @@ case class Conv(numExpr: Expression, fromBaseExpr: Expression, toBaseExpr: Expre /** Returns the result of evaluating this expression on a given input Row */ override def eval(input: InternalRow): Any = { val num = numExpr.eval(input) - val fromBase = fromBaseExpr.eval(input) - val toBase = toBaseExpr.eval(input) - if (num == null || fromBase == null || toBase == null) { - null - } else { - conv( - num.asInstanceOf[UTF8String].getBytes, - fromBase.asInstanceOf[Int], - toBase.asInstanceOf[Int]) - } - } - - private val value = new Array[Byte](64) - - /** - * Divide x by m as if x is an unsigned 64-bit integer. Examples: - * unsignedLongDiv(-1, 2) == Long.MAX_VALUE unsignedLongDiv(6, 3) == 2 - * unsignedLongDiv(0, 5) == 0 - * - * @param x is treated as unsigned - * @param m is treated as signed - */ - private def unsignedLongDiv(x: Long, m: Int): Long = { - if (x >= 0) { - x / m - } else { - // Let uval be the value of the unsigned long with the same bits as x - // Two's complement => x = uval - 2*MAX - 2 - // => uval = x + 2*MAX + 2 - // Now, use the fact: (a+b)/c = a/c + b/c + (a%c+b%c)/c - x / m + 2 * (Long.MaxValue / m) + 2 / m + (x % m + 2 * (Long.MaxValue % m) + 2 % m) / m - } - } - - /** - * Decode v into value[]. 
- * - * @param v is treated as an unsigned 64-bit integer - * @param radix must be between MIN_RADIX and MAX_RADIX - */ - private def decode(v: Long, radix: Int): Unit = { - var tmpV = v - java.util.Arrays.fill(value, 0.asInstanceOf[Byte]) - var i = value.length - 1 - while (tmpV != 0) { - val q = unsignedLongDiv(tmpV, radix) - value(i) = (tmpV - q * radix).asInstanceOf[Byte] - tmpV = q - i -= 1 - } - } - - /** - * Convert value[] into a long. On overflow, return -1 (as mySQL does). If a - * negative digit is found, ignore the suffix starting there. - * - * @param radix must be between MIN_RADIX and MAX_RADIX - * @param fromPos is the first element that should be conisdered - * @return the result should be treated as an unsigned 64-bit integer. - */ - private def encode(radix: Int, fromPos: Int): Long = { - var v: Long = 0L - val bound = unsignedLongDiv(-1 - radix, radix) // Possible overflow once - // val - // exceeds this value - var i = fromPos - while (i < value.length && value(i) >= 0) { - if (v >= bound) { - // Check for overflow - if (unsignedLongDiv(-1 - value(i), radix) < v) { - return -1 + if (num != null) { + val fromBase = fromBaseExpr.eval(input) + if (fromBase != null) { + val toBase = toBaseExpr.eval(input) + if (toBase != null) { + NumberConverter.convert( + num.asInstanceOf[UTF8String].getBytes, + fromBase.asInstanceOf[Int], + toBase.asInstanceOf[Int]) + } else { + null } - } - v = v * radix + value(i) - i += 1 - } - v - } - - /** - * Convert the bytes in value[] to the corresponding chars. - * - * @param radix must be between MIN_RADIX and MAX_RADIX - * @param fromPos is the first nonzero element - */ - private def byte2char(radix: Int, fromPos: Int): Unit = { - var i = fromPos - while (i < value.length) { - value(i) = Character.toUpperCase(Character.forDigit(value(i), radix)).asInstanceOf[Byte] - i += 1 - } - } - - /** - * Convert the chars in value[] to the corresponding integers. Convert invalid - * characters to -1. - * - * @param radix must be between MIN_RADIX and MAX_RADIX - * @param fromPos is the first nonzero element - */ - private def char2byte(radix: Int, fromPos: Int): Unit = { - var i = fromPos - while ( i < value.length) { - value(i) = Character.digit(value(i), radix).asInstanceOf[Byte] - i += 1 - } - } - - /** - * Convert numbers between different number bases. If toBase>0 the result is - * unsigned, otherwise it is signed. - * NB: This logic is borrowed from org.apache.hadoop.hive.ql.ud.UDFConv - */ - private def conv(n: Array[Byte] , fromBase: Int, toBase: Int ): UTF8String = { - if (fromBase < Character.MIN_RADIX || fromBase > Character.MAX_RADIX - || Math.abs(toBase) < Character.MIN_RADIX - || Math.abs(toBase) > Character.MAX_RADIX) { - return null - } - - if (n.length == 0) { - return null - } - - var (negative, first) = if (n(0) == '-') (true, 1) else (false, 0) - - // Copy the digits in the right side of the array - var i = 1 - while (i <= n.length - first) { - value(value.length - i) = n(n.length - i) - i += 1 - } - char2byte(fromBase, value.length - n.length + first) - - // Do the conversion by going through a 64 bit integer - var v = encode(fromBase, value.length - n.length + first) - if (negative && toBase > 0) { - if (v < 0) { - v = -1 } else { - v = -v + null } + } else { + null } - if (toBase < 0 && v < 0) { - v = -v - negative = true - } - decode(v, Math.abs(toBase)) - - // Find the first non-zero digit or the last digits if all are zero. 
- val firstNonZeroPos = { - val firstNonZero = value.indexWhere( _ != 0) - if (firstNonZero != -1) firstNonZero else value.length - 1 - } - - byte2char(Math.abs(toBase), firstNonZeroPos) + } - var resultStartPos = firstNonZeroPos - if (negative && toBase < 0) { - resultStartPos = firstNonZeroPos - 1 - value(resultStartPos) = '-' - } - UTF8String.fromBytes(java.util.Arrays.copyOfRange(value, resultStartPos, value.length)) + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val numGen = numExpr.gen(ctx) + val from = fromBaseExpr.gen(ctx) + val to = toBaseExpr.gen(ctx) + + val numconv = NumberConverter.getClass.getName.stripSuffix("$") + s""" + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + ${numGen.code} + boolean ${ev.isNull} = ${numGen.isNull}; + if (!${ev.isNull}) { + ${from.code} + if (!${from.isNull}) { + ${to.code} + if (!${to.isNull}) { + ${ev.primitive} = $numconv.convert(${numGen.primitive}.getBytes(), + ${from.primitive}, ${to.primitive}); + if (${ev.primitive} == null) { + ${ev.isNull} = true; + } + } else { + ${ev.isNull} = true; + } + } else { + ${ev.isNull} = true; + } + } + """ } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/NumberConverter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/NumberConverter.scala new file mode 100644 index 0000000000000..9fefc5656aac0 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/NumberConverter.scala @@ -0,0 +1,176 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.util + +import org.apache.spark.unsafe.types.UTF8String + +object NumberConverter { + + private val value = new Array[Byte](64) + + /** + * Divide x by m as if x is an unsigned 64-bit integer. Examples: + * unsignedLongDiv(-1, 2) == Long.MAX_VALUE unsignedLongDiv(6, 3) == 2 + * unsignedLongDiv(0, 5) == 0 + * + * @param x is treated as unsigned + * @param m is treated as signed + */ + private def unsignedLongDiv(x: Long, m: Int): Long = { + if (x >= 0) { + x / m + } else { + // Let uval be the value of the unsigned long with the same bits as x + // Two's complement => x = uval - 2*MAX - 2 + // => uval = x + 2*MAX + 2 + // Now, use the fact: (a+b)/c = a/c + b/c + (a%c+b%c)/c + x / m + 2 * (Long.MaxValue / m) + 2 / m + (x % m + 2 * (Long.MaxValue % m) + 2 % m) / m + } + } + + /** + * Decode v into value[]. 
+   *
+   * @param v is treated as an unsigned 64-bit integer
+   * @param radix must be between MIN_RADIX and MAX_RADIX
+   */
+  private def decode(v: Long, radix: Int): Unit = {
+    var tmpV = v
+    java.util.Arrays.fill(value, 0.asInstanceOf[Byte])
+    var i = value.length - 1
+    while (tmpV != 0) {
+      val q = unsignedLongDiv(tmpV, radix)
+      value(i) = (tmpV - q * radix).asInstanceOf[Byte]
+      tmpV = q
+      i -= 1
+    }
+  }
+
+  /**
+   * Convert value[] into a long. On overflow, return -1 (as MySQL does). If a
+   * negative digit is found, ignore the suffix starting there.
+   *
+   * @param radix must be between MIN_RADIX and MAX_RADIX
+   * @param fromPos is the first element that should be considered
+   * @return the result, to be treated as an unsigned 64-bit integer.
+   */
+  private def encode(radix: Int, fromPos: Int): Long = {
+    var v: Long = 0L
+    val bound = unsignedLongDiv(-1 - radix, radix) // Possible overflow once
+                                                   // v
+                                                   // exceeds this value
+    var i = fromPos
+    while (i < value.length && value(i) >= 0) {
+      if (v >= bound) {
+        // Check for overflow
+        if (unsignedLongDiv(-1 - value(i), radix) < v) {
+          return -1
+        }
+      }
+      v = v * radix + value(i)
+      i += 1
+    }
+    v
+  }
+
+  /**
+   * Convert the bytes in value[] to the corresponding chars.
+   *
+   * @param radix must be between MIN_RADIX and MAX_RADIX
+   * @param fromPos is the first nonzero element
+   */
+  private def byte2char(radix: Int, fromPos: Int): Unit = {
+    var i = fromPos
+    while (i < value.length) {
+      value(i) = Character.toUpperCase(Character.forDigit(value(i), radix)).asInstanceOf[Byte]
+      i += 1
+    }
+  }
+
+  /**
+   * Convert the chars in value[] to the corresponding integers. Convert invalid
+   * characters to -1.
+   *
+   * @param radix must be between MIN_RADIX and MAX_RADIX
+   * @param fromPos is the first nonzero element
+   */
+  private def char2byte(radix: Int, fromPos: Int): Unit = {
+    var i = fromPos
+    while (i < value.length) {
+      value(i) = Character.digit(value(i), radix).asInstanceOf[Byte]
+      i += 1
+    }
+  }
+
+  /**
+   * Convert numbers between different number bases. If toBase > 0 the result is
+   * unsigned, otherwise it is signed.
+   * NB: This logic is borrowed from org.apache.hadoop.hive.ql.udf.UDFConv
+   */
+  def convert(n: Array[Byte], fromBase: Int, toBase: Int): UTF8String = {
+    if (fromBase < Character.MIN_RADIX || fromBase > Character.MAX_RADIX
+        || Math.abs(toBase) < Character.MIN_RADIX
+        || Math.abs(toBase) > Character.MAX_RADIX) {
+      return null
+    }
+
+    if (n.length == 0) {
+      return null
+    }
+
+    var (negative, first) = if (n(0) == '-') (true, 1) else (false, 0)
+
+    // Copy the digits to the right side of the array
+    var i = 1
+    while (i <= n.length - first) {
+      value(value.length - i) = n(n.length - i)
+      i += 1
+    }
+    char2byte(fromBase, value.length - n.length + first)
+
+    // Do the conversion by going through a 64-bit integer
+    var v = encode(fromBase, value.length - n.length + first)
+    if (negative && toBase > 0) {
+      if (v < 0) {
+        v = -1
+      } else {
+        v = -v
+      }
+    }
+    if (toBase < 0 && v < 0) {
+      v = -v
+      negative = true
+    }
+    decode(v, Math.abs(toBase))
+
+    // Find the first non-zero digit or the last digits if all are zero.
+ val firstNonZeroPos = { + val firstNonZero = value.indexWhere( _ != 0) + if (firstNonZero != -1) firstNonZero else value.length - 1 + } + + byte2char(Math.abs(toBase), firstNonZeroPos) + + var resultStartPos = firstNonZeroPos + if (negative && toBase < 0) { + resultStartPos = firstNonZeroPos - 1 + value(resultStartPos) = '-' + } + UTF8String.fromBytes(java.util.Arrays.copyOfRange(value, resultStartPos, value.length)) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala index 04acd5b5ff4d1..a2b0fad7b7a04 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala @@ -115,8 +115,8 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Conv(Literal("-15"), Literal(10), Literal(-16)), "-F") checkEvaluation(Conv(Literal("-15"), Literal(10), Literal(16)), "FFFFFFFFFFFFFFF1") checkEvaluation(Conv(Literal("big"), Literal(36), Literal(16)), "3A48") - checkEvaluation(Conv(Literal(null), Literal(36), Literal(16)), null) - checkEvaluation(Conv(Literal("3"), Literal(null), Literal(16)), null) + checkEvaluation(Conv(Literal.create(null, StringType), Literal(36), Literal(16)), null) + checkEvaluation(Conv(Literal("3"), Literal.create(null, IntegerType), Literal(16)), null) checkEvaluation( Conv(Literal("1234"), Literal(10), Literal(37)), null) checkEvaluation( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/NumberConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/NumberConverterSuite.scala new file mode 100644 index 0000000000000..13265a1ff1c7f --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/NumberConverterSuite.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.util + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.util.NumberConverter.convert +import org.apache.spark.unsafe.types.UTF8String + +class NumberConverterSuite extends SparkFunSuite { + + private[this] def checkConv(n: String, fromBase: Int, toBase: Int, expected: String): Unit = { + assert(convert(UTF8String.fromString(n).getBytes, fromBase, toBase) === + UTF8String.fromString(expected)) + } + + test("convert") { + checkConv("3", 10, 2, "11") + checkConv("-15", 10, -16, "-F") + checkConv("-15", 10, 16, "FFFFFFFFFFFFFFF1") + checkConv("big", 36, 16, "3A48") + checkConv("9223372036854775807", 36, 16, "FFFFFFFFFFFFFFFF") + checkConv("11abc", 10, 16, "B") + } + +} From 4d97be95300f729391c17b4c162e3c7fba09b8bf Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Mon, 20 Jul 2015 22:15:10 -0700 Subject: [PATCH 09/32] [SPARK-9204][ML] Add default params test for linearyregression suite Author: Holden Karau Closes #7553 from holdenk/SPARK-9204-add-default-params-test-to-linear-regression and squashes the following commits: 630ba19 [Holden Karau] style fix faa08a3 [Holden Karau] Add default params test for linearyregression suite --- .../ml/regression/LinearRegressionSuite.scala | 25 +++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index 374002c5b4fdd..7cdda3db88ad1 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.ml.regression import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.mllib.linalg.{DenseVector, Vectors} import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} import org.apache.spark.mllib.util.TestingUtils._ @@ -55,6 +56,30 @@ class LinearRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { } + test("params") { + ParamsSuite.checkParams(new LinearRegression) + val model = new LinearRegressionModel("linearReg", Vectors.dense(0.0), 0.0) + ParamsSuite.checkParams(model) + } + + test("linear regression: default params") { + val lir = new LinearRegression + assert(lir.getLabelCol === "label") + assert(lir.getFeaturesCol === "features") + assert(lir.getPredictionCol === "prediction") + assert(lir.getRegParam === 0.0) + assert(lir.getElasticNetParam === 0.0) + assert(lir.getFitIntercept) + val model = lir.fit(dataset) + model.transform(dataset) + .select("label", "prediction") + .collect() + assert(model.getFeaturesCol === "features") + assert(model.getPredictionCol === "prediction") + assert(model.intercept !== 0.0) + assert(model.hasParent) + } + test("linear regression with intercept without regularization") { val trainer = new LinearRegression val model = trainer.fit(dataset) From c032b0bf92130dc4facb003f0deaeb1228aefded Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 20 Jul 2015 22:38:05 -0700 Subject: [PATCH 10/32] [SPARK-8797] [SPARK-9146] [SPARK-9145] [SPARK-9147] Support NaN ordering and equality comparisons in Spark SQL This patch addresses an issue where queries that sorted float or double columns containing NaN values could fail with "Comparison method violates its general contract!" errors from TimSort. 
The root of this problem is that `NaN > anything`, `NaN == anything`, and `NaN < anything` all return `false`. Per the design specified in SPARK-9079, we have decided that `NaN = NaN` should return true and that NaN should appear last when sorting in ascending order (i.e. it is larger than any other numeric value). In addition to implementing these semantics, this patch also adds canonicalization of NaN values in UnsafeRow, which is necessary in order to be able to do binary equality comparisons on equal NaNs that might have different bit representations (see SPARK-9147). Author: Josh Rosen Closes #7194 from JoshRosen/nan and squashes the following commits: 983d4fc [Josh Rosen] Merge remote-tracking branch 'origin/master' into nan 88bd73c [Josh Rosen] Fix Row.equals() a702e2e [Josh Rosen] normalization -> canonicalization a7267cf [Josh Rosen] Normalize NaNs in UnsafeRow fe629ae [Josh Rosen] Merge remote-tracking branch 'origin/master' into nan fbb2a29 [Josh Rosen] Fix NaN comparisons in BinaryComparison expressions c1fd4fe [Josh Rosen] Fold NaN test into existing test framework b31eb19 [Josh Rosen] Uncomment failing tests 7fe67af [Josh Rosen] Support NaN == NaN (SPARK-9145) 58bad2c [Josh Rosen] Revert "Compare rows' string representations to work around NaN incomparability." fc6b4d2 [Josh Rosen] Update CodeGenerator 3998ef2 [Josh Rosen] Remove unused code a2ba2e7 [Josh Rosen] Fix prefix comparision for NaNs a30d371 [Josh Rosen] Compare rows' string representations to work around NaN incomparability. 6f03f85 [Josh Rosen] Fix bug in Double / Float ordering 42a1ad5 [Josh Rosen] Stop filtering NaNs in UnsafeExternalSortSuite bfca524 [Josh Rosen] Change ordering so that NaN is maximum value. 8d7be61 [Josh Rosen] Update randomized test to use ScalaTest's assume() b20837b [Josh Rosen] Add failing test for new NaN comparision ordering 5b88b2b [Josh Rosen] Fix compilation of CodeGenerationSuite d907b5b [Josh Rosen] Merge remote-tracking branch 'origin/master' into nan 630ebc5 [Josh Rosen] Specify an ordering for NaN values. 9bf195a [Josh Rosen] Re-enable NaNs in CodeGenerationSuite to produce more regression tests 13fc06a [Josh Rosen] Add regression test for NaN sorting issue f9efbb5 [Josh Rosen] Fix ORDER BY NULL e7dc4fb [Josh Rosen] Add very generic test for ordering 7d5c13e [Josh Rosen] Add regression test for SPARK-8782 (ORDER BY NULL) b55875a [Josh Rosen] Generate doubles and floats over entire possible range. 5acdd5c [Josh Rosen] Infinity and NaN are interesting. ab76cbd [Josh Rosen] Move code to Catalyst package. d2b4a4a [Josh Rosen] Add random data generator test utilities to Spark SQL. 
--- .../unsafe/sort/PrefixComparators.java | 5 ++- .../scala/org/apache/spark/util/Utils.scala | 28 +++++++++++++ .../org/apache/spark/util/UtilsSuite.scala | 31 +++++++++++++++ .../unsafe/sort/PrefixComparatorsSuite.scala | 25 ++++++++++++ .../sql/catalyst/expressions/UnsafeRow.java | 6 +++ .../main/scala/org/apache/spark/sql/Row.scala | 24 ++++++++---- .../expressions/codegen/CodeGenerator.scala | 4 ++ .../sql/catalyst/expressions/predicates.scala | 22 +++++++++-- .../apache/spark/sql/types/DoubleType.scala | 5 ++- .../apache/spark/sql/types/FloatType.scala | 5 ++- .../expressions/CodeGenerationSuite.scala | 39 +++++++++++++++++++ .../catalyst/expressions/PredicateSuite.scala | 13 ++++--- .../expressions/UnsafeRowConverterSuite.scala | 22 +++++++++++ .../org/apache/spark/sql/DataFrameSuite.scala | 22 +++++++++++ .../scala/org/apache/spark/sql/RowSuite.scala | 12 ++++++ .../execution/UnsafeExternalSortSuite.scala | 6 +-- 16 files changed, 243 insertions(+), 26 deletions(-) diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java index 438742565c51d..bf1bc5dffba78 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java @@ -23,6 +23,7 @@ import org.apache.spark.annotation.Private; import org.apache.spark.unsafe.types.UTF8String; +import org.apache.spark.util.Utils; @Private public class PrefixComparators { @@ -82,7 +83,7 @@ public static final class FloatPrefixComparator extends PrefixComparator { public int compare(long aPrefix, long bPrefix) { float a = Float.intBitsToFloat((int) aPrefix); float b = Float.intBitsToFloat((int) bPrefix); - return (a < b) ? -1 : (a > b) ? 1 : 0; + return Utils.nanSafeCompareFloats(a, b); } public long computePrefix(float value) { @@ -97,7 +98,7 @@ public static final class DoublePrefixComparator extends PrefixComparator { public int compare(long aPrefix, long bPrefix) { double a = Double.longBitsToDouble(aPrefix); double b = Double.longBitsToDouble(bPrefix); - return (a < b) ? -1 : (a > b) ? 1 : 0; + return Utils.nanSafeCompareDoubles(a, b); } public long computePrefix(double value) { diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index e6374f17d858f..c5816949cd360 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -1586,6 +1586,34 @@ private[spark] object Utils extends Logging { hashAbs } + /** + * NaN-safe version of [[java.lang.Double.compare()]] which allows NaN values to be compared + * according to semantics where NaN == NaN and NaN > any non-NaN double. + */ + def nanSafeCompareDoubles(x: Double, y: Double): Int = { + val xIsNan: Boolean = java.lang.Double.isNaN(x) + val yIsNan: Boolean = java.lang.Double.isNaN(y) + if ((xIsNan && yIsNan) || (x == y)) 0 + else if (xIsNan) 1 + else if (yIsNan) -1 + else if (x > y) 1 + else -1 + } + + /** + * NaN-safe version of [[java.lang.Float.compare()]] which allows NaN values to be compared + * according to semantics where NaN == NaN and NaN > any non-NaN float. 
+ */ + def nanSafeCompareFloats(x: Float, y: Float): Int = { + val xIsNan: Boolean = java.lang.Float.isNaN(x) + val yIsNan: Boolean = java.lang.Float.isNaN(y) + if ((xIsNan && yIsNan) || (x == y)) 0 + else if (xIsNan) 1 + else if (yIsNan) -1 + else if (x > y) 1 + else -1 + } + /** Returns the system properties map that is thread-safe to iterator over. It gets the * properties which have been set explicitly, as well as those for which only a default value * has been defined. */ diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index c7638507c88c6..8f7e402d5f2a6 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.util import java.io.{File, ByteArrayOutputStream, ByteArrayInputStream, FileOutputStream} +import java.lang.{Double => JDouble, Float => JFloat} import java.net.{BindException, ServerSocket, URI} import java.nio.{ByteBuffer, ByteOrder} import java.text.DecimalFormatSymbols @@ -689,4 +690,34 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { // scalastyle:on println assert(buffer.toString === "t circular test circular\n") } + + test("nanSafeCompareDoubles") { + def shouldMatchDefaultOrder(a: Double, b: Double): Unit = { + assert(Utils.nanSafeCompareDoubles(a, b) === JDouble.compare(a, b)) + assert(Utils.nanSafeCompareDoubles(b, a) === JDouble.compare(b, a)) + } + shouldMatchDefaultOrder(0d, 0d) + shouldMatchDefaultOrder(0d, 1d) + shouldMatchDefaultOrder(Double.MinValue, Double.MaxValue) + assert(Utils.nanSafeCompareDoubles(Double.NaN, Double.NaN) === 0) + assert(Utils.nanSafeCompareDoubles(Double.NaN, Double.PositiveInfinity) === 1) + assert(Utils.nanSafeCompareDoubles(Double.NaN, Double.NegativeInfinity) === 1) + assert(Utils.nanSafeCompareDoubles(Double.PositiveInfinity, Double.NaN) === -1) + assert(Utils.nanSafeCompareDoubles(Double.NegativeInfinity, Double.NaN) === -1) + } + + test("nanSafeCompareFloats") { + def shouldMatchDefaultOrder(a: Float, b: Float): Unit = { + assert(Utils.nanSafeCompareFloats(a, b) === JFloat.compare(a, b)) + assert(Utils.nanSafeCompareFloats(b, a) === JFloat.compare(b, a)) + } + shouldMatchDefaultOrder(0f, 0f) + shouldMatchDefaultOrder(1f, 1f) + shouldMatchDefaultOrder(Float.MinValue, Float.MaxValue) + assert(Utils.nanSafeCompareFloats(Float.NaN, Float.NaN) === 0) + assert(Utils.nanSafeCompareFloats(Float.NaN, Float.PositiveInfinity) === 1) + assert(Utils.nanSafeCompareFloats(Float.NaN, Float.NegativeInfinity) === 1) + assert(Utils.nanSafeCompareFloats(Float.PositiveInfinity, Float.NaN) === -1) + assert(Utils.nanSafeCompareFloats(Float.NegativeInfinity, Float.NaN) === -1) + } } diff --git a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala index dd505dfa7d758..dc03e374b51db 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala @@ -47,4 +47,29 @@ class PrefixComparatorsSuite extends SparkFunSuite with PropertyChecks { forAll (regressionTests) { (s1: String, s2: String) => testPrefixComparison(s1, s2) } forAll { (s1: String, s2: String) => testPrefixComparison(s1, s2) } } + + test("float prefix comparator handles NaN properly") 
{ + val nan1: Float = java.lang.Float.intBitsToFloat(0x7f800001) + val nan2: Float = java.lang.Float.intBitsToFloat(0x7fffffff) + assert(nan1.isNaN) + assert(nan2.isNaN) + val nan1Prefix = PrefixComparators.FLOAT.computePrefix(nan1) + val nan2Prefix = PrefixComparators.FLOAT.computePrefix(nan2) + assert(nan1Prefix === nan2Prefix) + val floatMaxPrefix = PrefixComparators.FLOAT.computePrefix(Float.MaxValue) + assert(PrefixComparators.FLOAT.compare(nan1Prefix, floatMaxPrefix) === 1) + } + + test("double prefix comparator handles NaNs properly") { + val nan1: Double = java.lang.Double.longBitsToDouble(0x7ff0000000000001L) + val nan2: Double = java.lang.Double.longBitsToDouble(0x7fffffffffffffffL) + assert(nan1.isNaN) + assert(nan2.isNaN) + val nan1Prefix = PrefixComparators.DOUBLE.computePrefix(nan1) + val nan2Prefix = PrefixComparators.DOUBLE.computePrefix(nan2) + assert(nan1Prefix === nan2Prefix) + val doubleMaxPrefix = PrefixComparators.DOUBLE.computePrefix(Double.MaxValue) + assert(PrefixComparators.DOUBLE.compare(nan1Prefix, doubleMaxPrefix) === 1) + } + } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index 87294a0e21441..8cd9e7bc60a03 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -215,6 +215,9 @@ public void setLong(int ordinal, long value) { public void setDouble(int ordinal, double value) { assertIndexIsValid(ordinal); setNotNullAt(ordinal); + if (Double.isNaN(value)) { + value = Double.NaN; + } PlatformDependent.UNSAFE.putDouble(baseObject, getFieldOffset(ordinal), value); } @@ -243,6 +246,9 @@ public void setByte(int ordinal, byte value) { public void setFloat(int ordinal, float value) { assertIndexIsValid(ordinal); setNotNullAt(ordinal); + if (Float.isNaN(value)) { + value = Float.NaN; + } PlatformDependent.UNSAFE.putFloat(baseObject, getFieldOffset(ordinal), value); } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala index 2cb64d00935de..91449479fa539 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala @@ -403,20 +403,28 @@ trait Row extends Serializable { if (!isNullAt(i)) { val o1 = get(i) val o2 = other.get(i) - if (o1.isInstanceOf[Array[Byte]]) { - // handle equality of Array[Byte] - val b1 = o1.asInstanceOf[Array[Byte]] - if (!o2.isInstanceOf[Array[Byte]] || - !java.util.Arrays.equals(b1, o2.asInstanceOf[Array[Byte]])) { + o1 match { + case b1: Array[Byte] => + if (!o2.isInstanceOf[Array[Byte]] || + !java.util.Arrays.equals(b1, o2.asInstanceOf[Array[Byte]])) { + return false + } + case f1: Float if java.lang.Float.isNaN(f1) => + if (!o2.isInstanceOf[Float] || ! java.lang.Float.isNaN(o2.asInstanceOf[Float])) { + return false + } + case d1: Double if java.lang.Double.isNaN(d1) => + if (!o2.isInstanceOf[Double] || ! 
java.lang.Double.isNaN(o2.asInstanceOf[Double])) { + return false + } + case _ => if (o1 != o2) { return false } - } else if (o1 != o2) { - return false } } i += 1 } - return true + true } /* ---------------------- utility methods for Scala ---------------------- */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 10f411ff7451a..606f770cb4f7b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -194,6 +194,8 @@ class CodeGenContext { */ def genEqual(dataType: DataType, c1: String, c2: String): String = dataType match { case BinaryType => s"java.util.Arrays.equals($c1, $c2)" + case FloatType => s"(java.lang.Float.isNaN($c1) && java.lang.Float.isNaN($c2)) || $c1 == $c2" + case DoubleType => s"(java.lang.Double.isNaN($c1) && java.lang.Double.isNaN($c2)) || $c1 == $c2" case dt: DataType if isPrimitiveType(dt) => s"$c1 == $c2" case other => s"$c1.equals($c2)" } @@ -204,6 +206,8 @@ class CodeGenContext { def genComp(dataType: DataType, c1: String, c2: String): String = dataType match { // java boolean doesn't support > or < operator case BooleanType => s"($c1 == $c2 ? 0 : ($c1 ? 1 : -1))" + case DoubleType => s"org.apache.spark.util.Utils.nanSafeCompareDoubles($c1, $c2)" + case FloatType => s"org.apache.spark.util.Utils.nanSafeCompareFloats($c1, $c2)" // use c1 - c2 may overflow case dt: DataType if isPrimitiveType(dt) => s"($c1 > $c2 ? 1 : $c1 < $c2 ? -1 : 0)" case BinaryType => s"org.apache.spark.sql.catalyst.util.TypeUtils.compareBinary($c1, $c2)" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 40ec3df224ce1..a53ec31ee6a4b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils object InterpretedPredicate { @@ -222,7 +223,9 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P abstract class BinaryComparison extends BinaryOperator with Predicate { override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - if (ctx.isPrimitiveType(left.dataType)) { + if (ctx.isPrimitiveType(left.dataType) + && left.dataType != FloatType + && left.dataType != DoubleType) { // faster version defineCodeGen(ctx, ev, (c1, c2) => s"$c1 $symbol $c2") } else { @@ -254,8 +257,15 @@ case class EqualTo(left: Expression, right: Expression) extends BinaryComparison override def symbol: String = "=" protected override def nullSafeEval(input1: Any, input2: Any): Any = { - if (left.dataType != BinaryType) input1 == input2 - else java.util.Arrays.equals(input1.asInstanceOf[Array[Byte]], input2.asInstanceOf[Array[Byte]]) + if (left.dataType == FloatType) { + Utils.nanSafeCompareFloats(input1.asInstanceOf[Float], input2.asInstanceOf[Float]) == 0 + } else if (left.dataType == DoubleType) { + 
Utils.nanSafeCompareDoubles(input1.asInstanceOf[Double], input2.asInstanceOf[Double]) == 0 + } else if (left.dataType != BinaryType) { + input1 == input2 + } else { + java.util.Arrays.equals(input1.asInstanceOf[Array[Byte]], input2.asInstanceOf[Array[Byte]]) + } } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { @@ -280,7 +290,11 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp } else if (input1 == null || input2 == null) { false } else { - if (left.dataType != BinaryType) { + if (left.dataType == FloatType) { + Utils.nanSafeCompareFloats(input1.asInstanceOf[Float], input2.asInstanceOf[Float]) == 0 + } else if (left.dataType == DoubleType) { + Utils.nanSafeCompareDoubles(input1.asInstanceOf[Double], input2.asInstanceOf[Double]) == 0 + } else if (left.dataType != BinaryType) { input1 == input2 } else { java.util.Arrays.equals(input1.asInstanceOf[Array[Byte]], input2.asInstanceOf[Array[Byte]]) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DoubleType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DoubleType.scala index 986c2ab055386..2a1bf0938e5a8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DoubleType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DoubleType.scala @@ -23,6 +23,7 @@ import scala.reflect.runtime.universe.typeTag import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.catalyst.ScalaReflectionLock +import org.apache.spark.util.Utils /** * :: DeveloperApi :: @@ -37,7 +38,9 @@ class DoubleType private() extends FractionalType { @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] } private[sql] val numeric = implicitly[Numeric[Double]] private[sql] val fractional = implicitly[Fractional[Double]] - private[sql] val ordering = implicitly[Ordering[InternalType]] + private[sql] val ordering = new Ordering[Double] { + override def compare(x: Double, y: Double): Int = Utils.nanSafeCompareDoubles(x, y) + } private[sql] val asIntegral = DoubleAsIfIntegral /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/FloatType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/FloatType.scala index 9bd48ece83a1c..08e22252aef82 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/FloatType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/FloatType.scala @@ -23,6 +23,7 @@ import scala.reflect.runtime.universe.typeTag import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.catalyst.ScalaReflectionLock +import org.apache.spark.util.Utils /** * :: DeveloperApi :: @@ -37,7 +38,9 @@ class FloatType private() extends FractionalType { @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] } private[sql] val numeric = implicitly[Numeric[Float]] private[sql] val fractional = implicitly[Fractional[Float]] - private[sql] val ordering = implicitly[Ordering[InternalType]] + private[sql] val ordering = new Ordering[Float] { + override def compare(x: Float, y: Float): Int = Utils.nanSafeCompareFloats(x, y) + } private[sql] val asIntegral = FloatAsIfIntegral /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index e05218a23aa73..f4fbc49677ca3 100644 --- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -17,9 +17,14 @@ package org.apache.spark.sql.catalyst.expressions +import scala.math._ + import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.RandomDataGenerator +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.types.{DataTypeTestUtils, NullType, StructField, StructType} /** * Additional tests for code generation. @@ -43,6 +48,40 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { futures.foreach(Await.result(_, 10.seconds)) } + // Test GenerateOrdering for all common types. For each type, we construct random input rows that + // contain two columns of that type, then for pairs of randomly-generated rows we check that + // GenerateOrdering agrees with RowOrdering. + (DataTypeTestUtils.atomicTypes ++ Set(NullType)).foreach { dataType => + test(s"GenerateOrdering with $dataType") { + val rowOrdering = RowOrdering.forSchema(Seq(dataType, dataType)) + val genOrdering = GenerateOrdering.generate( + BoundReference(0, dataType, nullable = true).asc :: + BoundReference(1, dataType, nullable = true).asc :: Nil) + val rowType = StructType( + StructField("a", dataType, nullable = true) :: + StructField("b", dataType, nullable = true) :: Nil) + val maybeDataGenerator = RandomDataGenerator.forType(rowType, nullable = false) + assume(maybeDataGenerator.isDefined) + val randGenerator = maybeDataGenerator.get + val toCatalyst = CatalystTypeConverters.createToCatalystConverter(rowType) + for (_ <- 1 to 50) { + val a = toCatalyst(randGenerator()).asInstanceOf[InternalRow] + val b = toCatalyst(randGenerator()).asInstanceOf[InternalRow] + withClue(s"a = $a, b = $b") { + assert(genOrdering.compare(a, a) === 0) + assert(genOrdering.compare(b, b) === 0) + assert(rowOrdering.compare(a, a) === 0) + assert(rowOrdering.compare(b, b) === 0) + assert(signum(genOrdering.compare(a, b)) === -1 * signum(genOrdering.compare(b, a))) + assert(signum(rowOrdering.compare(a, b)) === -1 * signum(rowOrdering.compare(b, a))) + assert( + signum(rowOrdering.compare(a, b)) === signum(genOrdering.compare(a, b)), + "Generated and non-generated orderings should agree") + } + } + } + } + test("SPARK-8443: split wide projections into blocks due to JVM code size limit") { val length = 5000 val expressions = List.fill(length)(EqualTo(Literal(1), Literal(1))) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala index 2173a0c25c645..0bc2812a5dc83 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala @@ -136,11 +136,14 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(And(InSet(one, hS), InSet(two, hS)), true) } - private val smallValues = Seq(1, Decimal(1), Array(1.toByte), "a").map(Literal(_)) - private val largeValues = Seq(2, Decimal(2), Array(2.toByte), "b").map(Literal(_)) - - private val equalValues1 = smallValues - private val equalValues2 = Seq(1, Decimal(1), Array(1.toByte), "a").map(Literal(_)) + 
private val smallValues = Seq(1, Decimal(1), Array(1.toByte), "a", 0f, 0d).map(Literal(_)) + private val largeValues = + Seq(2, Decimal(2), Array(2.toByte), "b", Float.NaN, Double.NaN).map(Literal(_)) + + private val equalValues1 = + Seq(1, Decimal(1), Array(1.toByte), "a", Float.NaN, Double.NaN).map(Literal(_)) + private val equalValues2 = + Seq(1, Decimal(1), Array(1.toByte), "a", Float.NaN, Double.NaN).map(Literal(_)) test("BinaryComparison: <") { for (i <- 0 until smallValues.length) { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala index d00aeb4dfbf47..dff5faf9f6ec8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala @@ -316,4 +316,26 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers { assert(setToNullAfterCreation.get(11) === rowWithNoNullColumns.get(11)) } + test("NaN canonicalization") { + val fieldTypes: Array[DataType] = Array(FloatType, DoubleType) + + val row1 = new SpecificMutableRow(fieldTypes) + row1.setFloat(0, java.lang.Float.intBitsToFloat(0x7f800001)) + row1.setDouble(1, java.lang.Double.longBitsToDouble(0x7ff0000000000001L)) + + val row2 = new SpecificMutableRow(fieldTypes) + row2.setFloat(0, java.lang.Float.intBitsToFloat(0x7fffffff)) + row2.setDouble(1, java.lang.Double.longBitsToDouble(0x7fffffffffffffffL)) + + val converter = new UnsafeRowConverter(fieldTypes) + val row1Buffer = new Array[Byte](converter.getSizeRequirement(row1)) + val row2Buffer = new Array[Byte](converter.getSizeRequirement(row2)) + converter.writeRow( + row1, row1Buffer, PlatformDependent.BYTE_ARRAY_OFFSET, row1Buffer.length, null) + converter.writeRow( + row2, row2Buffer, PlatformDependent.BYTE_ARRAY_OFFSET, row2Buffer.length, null) + + assert(row1Buffer.toSeq === row2Buffer.toSeq) + } + } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 192cc0a6e5d7c..f67f2c60c0e16 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql import java.io.File import scala.language.postfixOps +import scala.util.Random import org.apache.spark.sql.catalyst.plans.logical.OneRowRelation import org.apache.spark.sql.functions._ @@ -742,6 +743,27 @@ class DataFrameSuite extends QueryTest with SQLTestUtils { df.col("t.``") } + test("SPARK-8797: sort by float column containing NaN should not crash") { + val inputData = Seq.fill(10)(Tuple1(Float.NaN)) ++ (1 to 1000).map(x => Tuple1(x.toFloat)) + val df = Random.shuffle(inputData).toDF("a") + df.orderBy("a").collect() + } + + test("SPARK-8797: sort by double column containing NaN should not crash") { + val inputData = Seq.fill(10)(Tuple1(Double.NaN)) ++ (1 to 1000).map(x => Tuple1(x.toDouble)) + val df = Random.shuffle(inputData).toDF("a") + df.orderBy("a").collect() + } + + test("NaN is greater than all other non-NaN numeric values") { + val maxDouble = Seq(Double.NaN, Double.PositiveInfinity, Double.MaxValue) + .map(Tuple1.apply).toDF("a").selectExpr("max(a)").first() + assert(java.lang.Double.isNaN(maxDouble.getDouble(0))) + val maxFloat = Seq(Float.NaN, Float.PositiveInfinity, 
Float.MaxValue) + .map(Tuple1.apply).toDF("a").selectExpr("max(a)").first() + assert(java.lang.Float.isNaN(maxFloat.getFloat(0))) + } + test("SPARK-8072: Better Exception for Duplicate Columns") { // only one duplicate column present val e = intercept[org.apache.spark.sql.AnalysisException] { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala index d84b57af9c882..7cc6ffd7548d0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala @@ -73,4 +73,16 @@ class RowSuite extends SparkFunSuite { row.getAs[Int]("c") } } + + test("float NaN == NaN") { + val r1 = Row(Float.NaN) + val r2 = Row(Float.NaN) + assert(r1 === r2) + } + + test("double NaN == NaN") { + val r1 = Row(Double.NaN) + val r2 = Row(Double.NaN) + assert(r1 === r2) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala index 4f4c1f28564cb..5fe73f7e0b072 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala @@ -83,11 +83,7 @@ class UnsafeExternalSortSuite extends SparkPlanTest with BeforeAndAfterAll { randomDataGenerator <- RandomDataGenerator.forType(dataType, nullable) ) { test(s"sorting on $dataType with nullable=$nullable, sortOrder=$sortOrder") { - val inputData = Seq.fill(1000)(randomDataGenerator()).filter { - case d: Double => !d.isNaN - case f: Float => !java.lang.Float.isNaN(f) - case x => true - } + val inputData = Seq.fill(1000)(randomDataGenerator()) val inputDf = TestSQLContext.createDataFrame( TestSQLContext.sparkContext.parallelize(Random.shuffle(inputData).map(v => Row(v))), StructType(StructField("a", dataType, nullable = true) :: Nil) From 560b355ccd038ca044726c9c9fcffd14d02e6696 Mon Sep 17 00:00:00 2001 From: Tarek Auel Date: Mon, 20 Jul 2015 22:43:30 -0700 Subject: [PATCH 11/32] [SPARK-9157] [SQL] codegen substring https://issues.apache.org/jira/browse/SPARK-9157 Author: Tarek Auel Closes #7534 from tarekauel/SPARK-9157 and squashes the following commits: e65e3e9 [Tarek Auel] [SPARK-9157] indent fix 44e89f8 [Tarek Auel] [SPARK-9157] use EMPTY_UTF8 37d54c4 [Tarek Auel] Merge branch 'master' into SPARK-9157 60732ea [Tarek Auel] [SPARK-9157] created substringSQL in UTF8String 18c3576 [Tarek Auel] [SPARK-9157][SQL] remove slice pos 1a2e611 [Tarek Auel] [SPARK-9157][SQL] codegen substring --- .../expressions/stringOperations.scala | 87 ++++++++++--------- .../apache/spark/unsafe/types/UTF8String.java | 12 +++ .../spark/unsafe/types/UTF8StringSuite.java | 19 ++++ 3 files changed, 75 insertions(+), 43 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 5c1908d55576a..438215e8e6e37 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -640,7 +640,7 @@ case class StringSplit(str: Expression, pattern: Expression) * Defined for String and Binary types. 
*/ case class Substring(str: Expression, pos: Expression, len: Expression) - extends Expression with ImplicitCastInputTypes with CodegenFallback { + extends Expression with ImplicitCastInputTypes { def this(str: Expression, pos: Expression) = { this(str, pos, Literal(Integer.MAX_VALUE)) @@ -649,58 +649,59 @@ case class Substring(str: Expression, pos: Expression, len: Expression) override def foldable: Boolean = str.foldable && pos.foldable && len.foldable override def nullable: Boolean = str.nullable || pos.nullable || len.nullable - override def dataType: DataType = { - if (!resolved) { - throw new UnresolvedException(this, s"Cannot resolve since $children are not resolved") - } - if (str.dataType == BinaryType) str.dataType else StringType - } + override def dataType: DataType = StringType override def inputTypes: Seq[DataType] = Seq(StringType, IntegerType, IntegerType) override def children: Seq[Expression] = str :: pos :: len :: Nil - @inline - def slicePos(startPos: Int, sliceLen: Int, length: () => Int): (Int, Int) = { - // Hive and SQL use one-based indexing for SUBSTR arguments but also accept zero and - // negative indices for start positions. If a start index i is greater than 0, it - // refers to element i-1 in the sequence. If a start index i is less than 0, it refers - // to the -ith element before the end of the sequence. If a start index i is 0, it - // refers to the first element. - - val start = startPos match { - case pos if pos > 0 => pos - 1 - case neg if neg < 0 => length() + neg - case _ => 0 - } - - val end = sliceLen match { - case max if max == Integer.MAX_VALUE => max - case x => start + x + override def eval(input: InternalRow): Any = { + val stringEval = str.eval(input) + if (stringEval != null) { + val posEval = pos.eval(input) + if (posEval != null) { + val lenEval = len.eval(input) + if (lenEval != null) { + stringEval.asInstanceOf[UTF8String] + .substringSQL(posEval.asInstanceOf[Int], lenEval.asInstanceOf[Int]) + } else { + null + } + } else { + null + } + } else { + null } - - (start, end) } - override def eval(input: InternalRow): Any = { - val string = str.eval(input) - val po = pos.eval(input) - val ln = len.eval(input) + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val strGen = str.gen(ctx) + val posGen = pos.gen(ctx) + val lenGen = len.gen(ctx) - if ((string == null) || (po == null) || (ln == null)) { - null - } else { - val start = po.asInstanceOf[Int] - val length = ln.asInstanceOf[Int] - string match { - case ba: Array[Byte] => - val (st, end) = slicePos(start, length, () => ba.length) - ba.slice(st, end) - case s: UTF8String => - val (st, end) = slicePos(start, length, () => s.numChars()) - s.substring(st, end) + val start = ctx.freshName("start") + val end = ctx.freshName("end") + + s""" + ${strGen.code} + boolean ${ev.isNull} = ${strGen.isNull}; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + ${posGen.code} + if (!${posGen.isNull}) { + ${lenGen.code} + if (!${lenGen.isNull}) { + ${ev.primitive} = ${strGen.primitive} + .substringSQL(${posGen.primitive}, ${lenGen.primitive}); + } else { + ${ev.isNull} = true; + } + } else { + ${ev.isNull} = true; + } } - } + """ } } diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index ed354f7f877f1..946d355f1fc28 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ 
b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -165,6 +165,18 @@ public UTF8String substring(final int start, final int until) { return fromBytes(bytes); } + public UTF8String substringSQL(int pos, int length) { + // Information regarding the pos calculation: + // Hive and SQL use one-based indexing for SUBSTR arguments but also accept zero and + // negative indices for start positions. If a start index i is greater than 0, it + // refers to element i-1 in the sequence. If a start index i is less than 0, it refers + // to the -ith element before the end of the sequence. If a start index i is 0, it + // refers to the first element. + int start = (pos > 0) ? pos -1 : ((pos < 0) ? numChars() + pos : 0); + int end = (length == Integer.MAX_VALUE) ? Integer.MAX_VALUE : start + length; + return substring(start, end); + } + /** * Returns whether this contains `substring` or not. */ diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java index 1f5572c509bdb..e2a5628ff4d93 100644 --- a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java +++ b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java @@ -272,6 +272,25 @@ public void pad() { fromString("数据砖头").rpad(12, fromString("孙行者"))); } + @Test + public void substringSQL() { + UTF8String e = fromString("example"); + assertEquals(e.substringSQL(0, 2), fromString("ex")); + assertEquals(e.substringSQL(1, 2), fromString("ex")); + assertEquals(e.substringSQL(0, 7), fromString("example")); + assertEquals(e.substringSQL(1, 2), fromString("ex")); + assertEquals(e.substringSQL(0, 100), fromString("example")); + assertEquals(e.substringSQL(1, 100), fromString("example")); + assertEquals(e.substringSQL(2, 2), fromString("xa")); + assertEquals(e.substringSQL(1, 6), fromString("exampl")); + assertEquals(e.substringSQL(2, 100), fromString("xample")); + assertEquals(e.substringSQL(0, 0), fromString("")); + assertEquals(e.substringSQL(100, 4), EMPTY_UTF8); + assertEquals(e.substringSQL(0, Integer.MAX_VALUE), fromString("example")); + assertEquals(e.substringSQL(1, Integer.MAX_VALUE), fromString("example")); + assertEquals(e.substringSQL(2, Integer.MAX_VALUE), fromString("xample")); + } + @Test public void split() { assertTrue(Arrays.equals(fromString("ab,def,ghi").split(fromString(","), -1), From 67570beed5950974126a91eacd48fd0fedfeb141 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 20 Jul 2015 22:48:13 -0700 Subject: [PATCH 12/32] [SPARK-9208][SQL] Remove variant of DataFrame string functions that accept column names. It can be ambiguous whether that is a string literal or a column name. cc marmbrus Author: Reynold Xin Closes #7556 from rxin/str-exprs and squashes the following commits: 92afa83 [Reynold Xin] [SPARK-9208][SQL] Remove variant of DataFrame string functions that accept column names. 
--- .../sql/catalyst/optimizer/Optimizer.scala | 2 +- .../org/apache/spark/sql/functions.scala | 459 ++---------------- .../spark/sql/DataFrameFunctionsSuite.scala | 8 +- .../spark/sql/MathExpressionsSuite.scala | 1 - .../spark/sql/StringFunctionsSuite.scala | 59 +-- 5 files changed, 74 insertions(+), 455 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index fafdae07c92f0..9c45b196245da 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -684,7 +684,7 @@ object CombineLimits extends Rule[LogicalPlan] { } /** - * Removes the inner [[CaseConversionExpression]] that are unnecessary because + * Removes the inner case conversion expressions that are unnecessary because * the inner conversion is overwritten by the outer one. */ object SimplifyCaseConversionExpressions extends Rule[LogicalPlan] { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 41b25d1836481..8fa017610b63c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -69,7 +69,7 @@ object functions { def column(colName: String): Column = Column(colName) /** - * Convert a number from one base to another for the specified expressions + * Convert a number in string format from one base to another. * * @group math_funcs * @since 1.5.0 @@ -77,15 +77,6 @@ object functions { def conv(num: Column, fromBase: Int, toBase: Int): Column = Conv(num.expr, lit(fromBase).expr, lit(toBase).expr) - /** - * Convert a number from one base to another for the specified expressions - * - * @group math_funcs - * @since 1.5.0 - */ - def conv(numColName: String, fromBase: Int, toBase: Int): Column = - conv(Column(numColName), fromBase, toBase) - /** * Creates a [[Column]] of literal value. * @@ -627,14 +618,6 @@ object functions { */ def isNaN(e: Column): Column = IsNaN(e.expr) - /** - * Converts a string expression to lower case. - * - * @group normal_funcs - * @since 1.3.0 - */ - def lower(e: Column): Column = Lower(e.expr) - /** * A column expression that generates monotonically increasing 64-bit integers. * @@ -791,14 +774,6 @@ object functions { struct((colName +: colNames).map(col) : _*) } - /** - * Converts a string expression to upper case. - * - * @group normal_funcs - * @since 1.3.0 - */ - def upper(e: Column): Column = Upper(e.expr) - /** * Computes bitwise NOT. 
* @@ -1106,9 +1081,8 @@ object functions { * @since 1.5.0 */ @scala.annotation.varargs - def greatest(exprs: Column*): Column = if (exprs.length < 2) { - sys.error("GREATEST takes at least 2 parameters") - } else { + def greatest(exprs: Column*): Column = { + require(exprs.length > 1, "greatest requires at least 2 arguments.") Greatest(exprs.map(_.expr)) } @@ -1120,9 +1094,7 @@ object functions { * @since 1.5.0 */ @scala.annotation.varargs - def greatest(columnName: String, columnNames: String*): Column = if (columnNames.isEmpty) { - sys.error("GREATEST takes at least 2 parameters") - } else { + def greatest(columnName: String, columnNames: String*): Column = { greatest((columnName +: columnNames).map(Column.apply): _*) } @@ -1134,14 +1106,6 @@ object functions { */ def hex(column: Column): Column = Hex(column.expr) - /** - * Computes hex value of the given input. - * - * @group math_funcs - * @since 1.5.0 - */ - def hex(colName: String): Column = hex(Column(colName)) - /** * Inverse of hex. Interprets each pair of characters as a hexadecimal number * and converts to the byte representation of number. @@ -1151,15 +1115,6 @@ object functions { */ def unhex(column: Column): Column = Unhex(column.expr) - /** - * Inverse of hex. Interprets each pair of characters as a hexadecimal number - * and converts to the byte representation of number. - * - * @group math_funcs - * @since 1.5.0 - */ - def unhex(colName: String): Column = unhex(Column(colName)) - /** * Computes `sqrt(a^2^ + b^2^)` without intermediate overflow or underflow. * @@ -1233,9 +1188,8 @@ object functions { * @since 1.5.0 */ @scala.annotation.varargs - def least(exprs: Column*): Column = if (exprs.length < 2) { - sys.error("LEAST takes at least 2 parameters") - } else { + def least(exprs: Column*): Column = { + require(exprs.length > 1, "least requires at least 2 arguments.") Least(exprs.map(_.expr)) } @@ -1247,9 +1201,7 @@ object functions { * @since 1.5.0 */ @scala.annotation.varargs - def least(columnName: String, columnNames: String*): Column = if (columnNames.isEmpty) { - sys.error("LEAST takes at least 2 parameters") - } else { + def least(columnName: String, columnNames: String*): Column = { least((columnName +: columnNames).map(Column.apply): _*) } @@ -1639,7 +1591,8 @@ object functions { ////////////////////////////////////////////////////////////////////////////////////////////// /** - * Calculates the MD5 digest and returns the value as a 32 character hex string. + * Calculates the MD5 digest of a binary column and returns the value + * as a 32 character hex string. * * @group misc_funcs * @since 1.5.0 @@ -1647,15 +1600,8 @@ object functions { def md5(e: Column): Column = Md5(e.expr) /** - * Calculates the MD5 digest and returns the value as a 32 character hex string. - * - * @group misc_funcs - * @since 1.5.0 - */ - def md5(columnName: String): Column = md5(Column(columnName)) - - /** - * Calculates the SHA-1 digest and returns the value as a 40 character hex string. + * Calculates the SHA-1 digest of a binary column and returns the value + * as a 40 character hex string. * * @group misc_funcs * @since 1.5.0 @@ -1663,15 +1609,11 @@ object functions { def sha1(e: Column): Column = Sha1(e.expr) /** - * Calculates the SHA-1 digest and returns the value as a 40 character hex string. + * Calculates the SHA-2 family of hash functions of a binary column and + * returns the value as a hex string. 
* - * @group misc_funcs - * @since 1.5.0 - */ - def sha1(columnName: String): Column = sha1(Column(columnName)) - - /** - * Calculates the SHA-2 family of hash functions and returns the value as a hex string. + * @param e column to compute SHA-2 on. + * @param numBits one of 224, 256, 384, or 512. * * @group misc_funcs * @since 1.5.0 @@ -1683,29 +1625,14 @@ object functions { } /** - * Calculates the SHA-2 family of hash functions and returns the value as a hex string. - * - * @group misc_funcs - * @since 1.5.0 - */ - def sha2(columnName: String, numBits: Int): Column = sha2(Column(columnName), numBits) - - /** - * Calculates the cyclic redundancy check value and returns the value as a bigint. + * Calculates the cyclic redundancy check value (CRC32) of a binary column and + * returns the value as a bigint. * * @group misc_funcs * @since 1.5.0 */ def crc32(e: Column): Column = Crc32(e.expr) - /** - * Calculates the cyclic redundancy check value and returns the value as a bigint. - * - * @group misc_funcs - * @since 1.5.0 - */ - def crc32(columnName: String): Column = crc32(Column(columnName)) - ////////////////////////////////////////////////////////////////////////////////////////////// // String functions ////////////////////////////////////////////////////////////////////////////////////////////// @@ -1719,19 +1646,6 @@ object functions { @scala.annotation.varargs def concat(exprs: Column*): Column = Concat(exprs.map(_.expr)) - /** - * Concatenates input strings together into a single string. - * - * This is the variant of concat that takes in the column names. - * - * @group string_funcs - * @since 1.5.0 - */ - @scala.annotation.varargs - def concat(columnName: String, columnNames: String*): Column = { - concat((columnName +: columnNames).map(Column.apply): _*) - } - /** * Concatenates input strings together into a single string, using the given separator. * @@ -1743,19 +1657,6 @@ object functions { ConcatWs(Literal.create(sep, StringType) +: exprs.map(_.expr)) } - /** - * Concatenates input strings together into a single string, using the given separator. - * - * This is the variant of concat_ws that takes in the column names. - * - * @group string_funcs - * @since 1.5.0 - */ - @scala.annotation.varargs - def concat_ws(sep: String, columnName: String, columnNames: String*): Column = { - concat_ws(sep, (columnName +: columnNames).map(Column.apply) : _*) - } - /** * Computes the length of a given string / binary value. * @@ -1765,23 +1666,20 @@ object functions { def length(e: Column): Column = Length(e.expr) /** - * Computes the length of a given string / binary column. + * Converts a string expression to lower case. * * @group string_funcs - * @since 1.5.0 + * @since 1.3.0 */ - def length(columnName: String): Column = length(Column(columnName)) + def lower(e: Column): Column = Lower(e.expr) /** - * Formats the number X to a format like '#,###,###.##', rounded to d decimal places, - * and returns the result as a string. - * If d is 0, the result has no decimal point or fractional part. - * If d < 0, the result will be null. + * Converts a string expression to upper case. 
* * @group string_funcs - * @since 1.5.0 + * @since 1.3.0 */ - def format_number(x: Column, d: Int): Column = FormatNumber(x.expr, lit(d).expr) + def upper(e: Column): Column = Upper(e.expr) /** * Formats the number X to a format like '#,###,###.##', rounded to d decimal places, @@ -1792,57 +1690,31 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def format_number(columnXName: String, d: Int): Column = { - format_number(Column(columnXName), d) - } + def format_number(x: Column, d: Int): Column = FormatNumber(x.expr, lit(d).expr) /** - * Computes the Levenshtein distance of the two given strings. + * Computes the Levenshtein distance of the two given string columns. * @group string_funcs * @since 1.5.0 */ def levenshtein(l: Column, r: Column): Column = Levenshtein(l.expr, r.expr) - /** - * Computes the Levenshtein distance of the two given strings. - * @group string_funcs - * @since 1.5.0 - */ - def levenshtein(leftColumnName: String, rightColumnName: String): Column = - levenshtein(Column(leftColumnName), Column(rightColumnName)) - - /** - * Computes the numeric value of the first character of the specified string value. - * - * @group string_funcs - * @since 1.5.0 - */ - def ascii(e: Column): Column = Ascii(e.expr) - /** * Computes the numeric value of the first character of the specified string column. * * @group string_funcs * @since 1.5.0 */ - def ascii(columnName: String): Column = ascii(Column(columnName)) + def ascii(e: Column): Column = Ascii(e.expr) /** - * Trim the spaces from both ends for the specified string value. + * Trim the spaces from both ends for the specified string column. * * @group string_funcs * @since 1.5.0 */ def trim(e: Column): Column = StringTrim(e.expr) - /** - * Trim the spaces from both ends for the specified column. - * - * @group string_funcs - * @since 1.5.0 - */ - def trim(columnName: String): Column = trim(Column(columnName)) - /** * Trim the spaces from left end for the specified string value. * @@ -1851,14 +1723,6 @@ object functions { */ def ltrim(e: Column): Column = StringTrimLeft(e.expr) - /** - * Trim the spaces from left end for the specified column. - * - * @group string_funcs - * @since 1.5.0 - */ - def ltrim(columnName: String): Column = ltrim(Column(columnName)) - /** * Trim the spaces from right end for the specified string value. * @@ -1867,25 +1731,6 @@ object functions { */ def rtrim(e: Column): Column = StringTrimRight(e.expr) - /** - * Trim the spaces from right end for the specified column. - * - * @group string_funcs - * @since 1.5.0 - */ - def rtrim(columnName: String): Column = rtrim(Column(columnName)) - - /** - * Format strings in printf-style. - * - * @group string_funcs - * @since 1.5.0 - */ - @scala.annotation.varargs - def formatString(format: Column, arguments: Column*): Column = { - StringFormat((format +: arguments).map(_.expr): _*) - } - /** * Format strings in printf-style. * NOTE: `format` is the string value of the formatter, not column name. @@ -1898,18 +1743,6 @@ object functions { StringFormat(lit(format).expr +: arguNames.map(Column(_).expr): _*) } - /** - * Locate the position of the first occurrence of substr value in the given string. - * Returns null if either of the arguments are null. - * - * NOTE: The position is not zero based, but 1 based index, returns 0 if substr - * could not be found in str. 
- * - * @group string_funcs - * @since 1.5.0 - */ - def instr(substr: String, sub: String): Column = instr(Column(substr), Column(sub)) - /** * Locate the position of the first occurrence of substr column in the given string. * Returns null if either of the arguments are null. @@ -1920,10 +1753,10 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def instr(substr: Column, sub: Column): Column = StringInstr(substr.expr, sub.expr) + def instr(str: Column, substring: String): Column = StringInstr(str.expr, lit(substring).expr) /** - * Locate the position of the first occurrence of substr. + * Locate the position of the first occurrence of substr in a string column. * * NOTE: The position is not zero based, but 1 based index, returns 0 if substr * could not be found in str. @@ -1931,77 +1764,26 @@ object functions { * @group string_funcs * @since 1.5.0 */ - def locate(substr: String, str: String): Column = { - locate(Column(substr), Column(str)) + def locate(substr: String, str: Column): Column = { + new StringLocate(lit(substr).expr, str.expr) } /** - * Locate the position of the first occurrence of substr. + * Locate the position of the first occurrence of substr in a string column, after position pos. * - * NOTE: The position is not zero based, but 1 based index, returns 0 if substr + * NOTE: The position is not zero based, but 1 based index. returns 0 if substr * could not be found in str. * * @group string_funcs * @since 1.5.0 */ - def locate(substr: Column, str: Column): Column = { - new StringLocate(substr.expr, str.expr) + def locate(substr: String, str: Column, pos: Int): Column = { + StringLocate(lit(substr).expr, str.expr, lit(pos).expr) } /** - * Locate the position of the first occurrence of substr in a given string after position pos. - * - * NOTE: The position is not zero based, but 1 based index, returns 0 if substr - * could not be found in str. - * - * @group string_funcs - * @since 1.5.0 - */ - def locate(substr: String, str: String, pos: String): Column = { - locate(Column(substr), Column(str), Column(pos)) - } - - /** - * Locate the position of the first occurrence of substr in a given string after position pos. - * - * NOTE: The position is not zero based, but 1 based index, returns 0 if substr - * could not be found in str. - * - * @group string_funcs - * @since 1.5.0 - */ - def locate(substr: Column, str: Column, pos: Column): Column = { - StringLocate(substr.expr, str.expr, pos.expr) - } - - /** - * Locate the position of the first occurrence of substr in a given string after position pos. - * - * NOTE: The position is not zero based, but 1 based index, returns 0 if substr - * could not be found in str. - * - * @group string_funcs - * @since 1.5.0 - */ - def locate(substr: Column, str: Column, pos: Int): Column = { - StringLocate(substr.expr, str.expr, lit(pos).expr) - } - - /** - * Locate the position of the first occurrence of substr in a given string after position pos. - * - * NOTE: The position is not zero based, but 1 based index, returns 0 if substr - * could not be found in str. - * - * @group string_funcs - * @since 1.5.0 - */ - def locate(substr: String, str: String, pos: Int): Column = { - locate(Column(substr), Column(str), lit(pos)) - } - - /** - * Computes the specified value from binary to a base64 string. + * Computes the BASE64 encoding of a binary column and returns it as a string column. + * This is the reverse of unbase64. 
* * @group string_funcs * @since 1.5.0 @@ -2009,67 +1791,22 @@ object functions { def base64(e: Column): Column = Base64(e.expr) /** - * Computes the specified column from binary to a base64 string. - * - * @group string_funcs - * @since 1.5.0 - */ - def base64(columnName: String): Column = base64(Column(columnName)) - - /** - * Computes the specified value from a base64 string to binary. + * Decodes a BASE64 encoded string column and returns it as a binary column. + * This is the reverse of base64. * * @group string_funcs * @since 1.5.0 */ def unbase64(e: Column): Column = UnBase64(e.expr) - /** - * Computes the specified column from a base64 string to binary. - * - * @group string_funcs - * @since 1.5.0 - */ - def unbase64(columnName: String): Column = unbase64(Column(columnName)) - /** * Left-padded with pad to a length of len. * * @group string_funcs * @since 1.5.0 */ - def lpad(str: String, len: String, pad: String): Column = { - lpad(Column(str), Column(len), Column(pad)) - } - - /** - * Left-padded with pad to a length of len. - * - * @group string_funcs - * @since 1.5.0 - */ - def lpad(str: Column, len: Column, pad: Column): Column = { - StringLPad(str.expr, len.expr, pad.expr) - } - - /** - * Left-padded with pad to a length of len. - * - * @group string_funcs - * @since 1.5.0 - */ - def lpad(str: Column, len: Int, pad: Column): Column = { - StringLPad(str.expr, lit(len).expr, pad.expr) - } - - /** - * Left-padded with pad to a length of len. - * - * @group string_funcs - * @since 1.5.0 - */ - def lpad(str: String, len: Int, pad: String): Column = { - lpad(Column(str), len, Column(pad)) + def lpad(str: Column, len: Int, pad: String): Column = { + StringLPad(str.expr, lit(len).expr, lit(pad).expr) } /** @@ -2082,18 +1819,6 @@ object functions { */ def encode(value: Column, charset: String): Column = Encode(value.expr, lit(charset).expr) - /** - * Computes the first argument into a binary from a string using the provided character set - * (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16'). - * If either argument is null, the result will also be null. - * NOTE: charset represents the string value of the character set, not the column name. - * - * @group string_funcs - * @since 1.5.0 - */ - def encode(columnName: String, charset: String): Column = - encode(Column(columnName), charset) - /** * Computes the first argument into a string from a binary using the provided character set * (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16'). @@ -2104,106 +1829,24 @@ object functions { */ def decode(value: Column, charset: String): Column = Decode(value.expr, lit(charset).expr) - /** - * Computes the first argument into a string from a binary using the provided character set - * (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16'). - * If either argument is null, the result will also be null. - * NOTE: charset represents the string value of the character set, not the column name. - * - * @group string_funcs - * @since 1.5.0 - */ - def decode(columnName: String, charset: String): Column = - decode(Column(columnName), charset) - - /** - * Right-padded with pad to a length of len. - * - * @group string_funcs - * @since 1.5.0 - */ - def rpad(str: String, len: String, pad: String): Column = { - rpad(Column(str), Column(len), Column(pad)) - } - - /** - * Right-padded with pad to a length of len. 
- * - * @group string_funcs - * @since 1.5.0 - */ - def rpad(str: Column, len: Column, pad: Column): Column = { - StringRPad(str.expr, len.expr, pad.expr) - } - - /** - * Right-padded with pad to a length of len. - * - * @group string_funcs - * @since 1.5.0 - */ - def rpad(str: String, len: Int, pad: String): Column = { - rpad(Column(str), len, Column(pad)) - } - /** * Right-padded with pad to a length of len. * * @group string_funcs * @since 1.5.0 */ - def rpad(str: Column, len: Int, pad: Column): Column = { - StringRPad(str.expr, lit(len).expr, pad.expr) - } - - /** - * Repeat the string value of the specified column n times. - * - * @group string_funcs - * @since 1.5.0 - */ - def repeat(strColumn: String, timesColumn: String): Column = { - repeat(Column(strColumn), Column(timesColumn)) + def rpad(str: Column, len: Int, pad: String): Column = { + StringRPad(str.expr, lit(len).expr, lit(pad).expr) } /** - * Repeat the string expression value n times. + * Repeats a string column n times, and returns it as a new string column. * * @group string_funcs * @since 1.5.0 */ - def repeat(str: Column, times: Column): Column = { - StringRepeat(str.expr, times.expr) - } - - /** - * Repeat the string value of the specified column n times. - * - * @group string_funcs - * @since 1.5.0 - */ - def repeat(strColumn: String, times: Int): Column = { - repeat(Column(strColumn), times) - } - - /** - * Repeat the string expression value n times. - * - * @group string_funcs - * @since 1.5.0 - */ - def repeat(str: Column, times: Int): Column = { - StringRepeat(str.expr, lit(times).expr) - } - - /** - * Splits str around pattern (pattern is a regular expression). - * - * @group string_funcs - * @since 1.5.0 - */ - def split(strColumnName: String, pattern: String): Column = { - split(Column(strColumnName), pattern) + def repeat(str: Column, n: Int): Column = { + StringRepeat(str.expr, lit(n).expr) } /** @@ -2217,16 +1860,6 @@ object functions { StringSplit(str.expr, lit(pattern).expr) } - /** - * Reversed the string for the specified column. - * - * @group string_funcs - * @since 1.5.0 - */ - def reverse(str: String): Column = { - reverse(Column(str)) - } - /** * Reversed the string for the specified value. 
* diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 29f1197a8543c..8d2ff2f9690d6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -160,7 +160,7 @@ class DataFrameFunctionsSuite extends QueryTest { test("misc md5 function") { val df = Seq(("ABC", Array[Byte](1, 2, 3, 4, 5, 6))).toDF("a", "b") checkAnswer( - df.select(md5($"a"), md5("b")), + df.select(md5($"a"), md5($"b")), Row("902fbdd2b1df0c4f70b4a5d23525e932", "6ac1e56bc78f031059be7be854522c4c")) checkAnswer( @@ -171,7 +171,7 @@ class DataFrameFunctionsSuite extends QueryTest { test("misc sha1 function") { val df = Seq(("ABC", "ABC".getBytes)).toDF("a", "b") checkAnswer( - df.select(sha1($"a"), sha1("b")), + df.select(sha1($"a"), sha1($"b")), Row("3c01bdbb26f358bab27f267924aa2c9a03fcfdb8", "3c01bdbb26f358bab27f267924aa2c9a03fcfdb8")) val dfEmpty = Seq(("", "".getBytes)).toDF("a", "b") @@ -183,7 +183,7 @@ class DataFrameFunctionsSuite extends QueryTest { test("misc sha2 function") { val df = Seq(("ABC", Array[Byte](1, 2, 3, 4, 5, 6))).toDF("a", "b") checkAnswer( - df.select(sha2($"a", 256), sha2("b", 256)), + df.select(sha2($"a", 256), sha2($"b", 256)), Row("b5d4045c3f466fa91fe2cc6abe79232a1a57cdf104f7a26e716e0a1e2789df78", "7192385c3c0605de55bb9476ce1d90748190ecb32a8eed7f5207b30cf6a1fe89")) @@ -200,7 +200,7 @@ class DataFrameFunctionsSuite extends QueryTest { test("misc crc32 function") { val df = Seq(("ABC", Array[Byte](1, 2, 3, 4, 5, 6))).toDF("a", "b") checkAnswer( - df.select(crc32($"a"), crc32("b")), + df.select(crc32($"a"), crc32($"b")), Row(2743272264L, 2180413220L)) checkAnswer( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala index a51523f1a7a0f..21256704a5b16 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala @@ -176,7 +176,6 @@ class MathExpressionsSuite extends QueryTest { test("conv") { val df = Seq(("333", 10, 2)).toDF("num", "fromBase", "toBase") checkAnswer(df.select(conv('num, 10, 16)), Row("14D")) - checkAnswer(df.select(conv("num", 10, 16)), Row("14D")) checkAnswer(df.select(conv(lit(100), 2, 16)), Row("4")) checkAnswer(df.select(conv(lit(3122234455L), 10, 16)), Row("BA198457")) checkAnswer(df.selectExpr("conv(num, fromBase, toBase)"), Row("101001101")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala index 413f3858d6764..4551192b157ff 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala @@ -52,14 +52,14 @@ class StringFunctionsSuite extends QueryTest { test("string Levenshtein distance") { val df = Seq(("kitten", "sitting"), ("frog", "fog")).toDF("l", "r") - checkAnswer(df.select(levenshtein("l", "r")), Seq(Row(3), Row(1))) + checkAnswer(df.select(levenshtein($"l", $"r")), Seq(Row(3), Row(1))) checkAnswer(df.selectExpr("levenshtein(l, r)"), Seq(Row(3), Row(1))) } test("string ascii function") { val df = Seq(("abc", "")).toDF("a", "b") checkAnswer( - df.select(ascii($"a"), ascii("b")), + df.select(ascii($"a"), ascii($"b")), Row(97, 0)) 
checkAnswer( @@ -71,8 +71,8 @@ class StringFunctionsSuite extends QueryTest { val bytes = Array[Byte](1, 2, 3, 4) val df = Seq((bytes, "AQIDBA==")).toDF("a", "b") checkAnswer( - df.select(base64("a"), base64($"a"), unbase64("b"), unbase64($"b")), - Row("AQIDBA==", "AQIDBA==", bytes, bytes)) + df.select(base64($"a"), unbase64($"b")), + Row("AQIDBA==", bytes)) checkAnswer( df.selectExpr("base64(a)", "unbase64(b)"), @@ -85,12 +85,8 @@ class StringFunctionsSuite extends QueryTest { // non ascii characters are not allowed in the code, so we disable the scalastyle here. val df = Seq(("大千世界", "utf-8", bytes)).toDF("a", "b", "c") checkAnswer( - df.select( - encode($"a", "utf-8"), - encode("a", "utf-8"), - decode($"c", "utf-8"), - decode("c", "utf-8")), - Row(bytes, bytes, "大千世界", "大千世界")) + df.select(encode($"a", "utf-8"), decode($"c", "utf-8")), + Row(bytes, "大千世界")) checkAnswer( df.selectExpr("encode(a, 'utf-8')", "decode(c, 'utf-8')"), @@ -114,8 +110,8 @@ class StringFunctionsSuite extends QueryTest { val df = Seq(("aa%d%s", 123, "cc")).toDF("a", "b", "c") checkAnswer( - df.select(formatString($"a", $"b", $"c"), formatString("aa%d%s", "b", "c")), - Row("aa123cc", "aa123cc")) + df.select(formatString("aa%d%s", "b", "c")), + Row("aa123cc")) checkAnswer( df.selectExpr("printf(a, b, c)"), @@ -126,8 +122,8 @@ class StringFunctionsSuite extends QueryTest { val df = Seq(("aaads", "aa", "zz")).toDF("a", "b", "c") checkAnswer( - df.select(instr($"a", $"b"), instr("a", "b")), - Row(1, 1)) + df.select(instr($"a", "aa")), + Row(1)) checkAnswer( df.selectExpr("instr(a, b)"), @@ -138,10 +134,8 @@ class StringFunctionsSuite extends QueryTest { val df = Seq(("aaads", "aa", "zz", 1)).toDF("a", "b", "c", "d") checkAnswer( - df.select( - locate($"b", $"a"), locate("b", "a"), locate($"b", $"a", 1), - locate("b", "a", 1), locate($"b", $"a", $"d"), locate("b", "a", "d")), - Row(1, 1, 2, 2, 2, 2)) + df.select(locate("aa", $"a"), locate("aa", $"a", 1)), + Row(1, 2)) checkAnswer( df.selectExpr("locate(b, a)", "locate(b, a, d)"), @@ -152,10 +146,8 @@ class StringFunctionsSuite extends QueryTest { val df = Seq(("hi", 5, "??")).toDF("a", "b", "c") checkAnswer( - df.select( - lpad($"a", $"b", $"c"), rpad("a", "b", "c"), - lpad($"a", 1, $"c"), rpad("a", 1, "c")), - Row("???hi", "hi???", "h", "h")) + df.select(lpad($"a", 1, "c"), lpad($"a", 5, "??"), rpad($"a", 1, "c"), rpad($"a", 5, "??")), + Row("h", "???hi", "h", "hi???")) checkAnswer( df.selectExpr("lpad(a, b, c)", "rpad(a, b, c)", "lpad(a, 1, c)", "rpad(a, 1, c)"), @@ -166,9 +158,8 @@ class StringFunctionsSuite extends QueryTest { val df = Seq(("hi", 2)).toDF("a", "b") checkAnswer( - df.select( - repeat($"a", 2), repeat("a", 2), repeat($"a", $"b"), repeat("a", "b")), - Row("hihi", "hihi", "hihi", "hihi")) + df.select(repeat($"a", 2)), + Row("hihi")) checkAnswer( df.selectExpr("repeat(a, 2)", "repeat(a, b)"), @@ -179,7 +170,7 @@ class StringFunctionsSuite extends QueryTest { val df = Seq(("hi", "hhhi")).toDF("a", "b") checkAnswer( - df.select(reverse($"a"), reverse("b")), + df.select(reverse($"a"), reverse($"b")), Row("ih", "ihhh")) checkAnswer( @@ -199,10 +190,8 @@ class StringFunctionsSuite extends QueryTest { val df = Seq(("aa2bb3cc", "[1-9]+")).toDF("a", "b") checkAnswer( - df.select( - split($"a", "[1-9]+"), - split("a", "[1-9]+")), - Row(Seq("aa", "bb", "cc"), Seq("aa", "bb", "cc"))) + df.select(split($"a", "[1-9]+")), + Row(Seq("aa", "bb", "cc"))) checkAnswer( df.selectExpr("split(a, '[1-9]+')"), @@ -212,8 +201,8 @@ class StringFunctionsSuite extends QueryTest { 
test("string / binary length function") { val df = Seq(("123", Array[Byte](1, 2, 3, 4), 123)).toDF("a", "b", "c") checkAnswer( - df.select(length($"a"), length("a"), length($"b"), length("b")), - Row(3, 3, 4, 4)) + df.select(length($"a"), length($"b")), + Row(3, 4)) checkAnswer( df.selectExpr("length(a)", "length(b)"), @@ -243,10 +232,8 @@ class StringFunctionsSuite extends QueryTest { "h") // decimal 7.128381 checkAnswer( - df.select( - format_number($"f", 4), - format_number("f", 4)), - Row("5.0000", "5.0000")) + df.select(format_number($"f", 4)), + Row("5.0000")) checkAnswer( df.selectExpr("format_number(b, e)"), // convert the 1st argument to integer From 48f8fd46b32973f1f3b865da80345698cb1a71c7 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Mon, 20 Jul 2015 23:28:35 -0700 Subject: [PATCH 13/32] [SPARK-9023] [SQL] Followup for #7456 (Efficiency improvements for UnsafeRows in Exchange) This patch addresses code review feedback from #7456. Author: Josh Rosen Closes #7551 from JoshRosen/unsafe-exchange-followup and squashes the following commits: 76dbdf8 [Josh Rosen] Add comments + more methods to UnsafeRowSerializer 3d7a1f2 [Josh Rosen] Add writeToStream() method to UnsafeRow --- .../sql/catalyst/expressions/UnsafeRow.java | 33 ++++++++ .../sql/execution/UnsafeRowSerializer.scala | 80 +++++++++++++------ .../org/apache/spark/sql/UnsafeRowSuite.scala | 71 ++++++++++++++++ 3 files changed, 161 insertions(+), 23 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index 8cd9e7bc60a03..6ce03a48e9538 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -17,6 +17,9 @@ package org.apache.spark.sql.catalyst.expressions; +import java.io.IOException; +import java.io.OutputStream; + import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.util.ObjectPool; import org.apache.spark.unsafe.PlatformDependent; @@ -371,6 +374,36 @@ public InternalRow copy() { } } + /** + * Write this UnsafeRow's underlying bytes to the given OutputStream. + * + * @param out the stream to write to. + * @param writeBuffer a byte array for buffering chunks of off-heap data while writing to the + * output stream. If this row is backed by an on-heap byte array, then this + * buffer will not be used and may be null. 
+ */ + public void writeToStream(OutputStream out, byte[] writeBuffer) throws IOException { + if (baseObject instanceof byte[]) { + int offsetInByteArray = (int) (PlatformDependent.BYTE_ARRAY_OFFSET - baseOffset); + out.write((byte[]) baseObject, offsetInByteArray, sizeInBytes); + } else { + int dataRemaining = sizeInBytes; + long rowReadPosition = baseOffset; + while (dataRemaining > 0) { + int toTransfer = Math.min(writeBuffer.length, dataRemaining); + PlatformDependent.copyMemory( + baseObject, + rowReadPosition, + writeBuffer, + PlatformDependent.BYTE_ARRAY_OFFSET, + toTransfer); + out.write(writeBuffer, 0, toTransfer); + rowReadPosition += toTransfer; + dataRemaining -= toTransfer; + } + } + } + @Override public boolean anyNull() { return BitSetMethods.anySet(baseObject, baseOffset, bitSetWidthInBytes); diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala index 19503ed00056c..318550e5ed899 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala @@ -49,8 +49,16 @@ private[sql] class UnsafeRowSerializer(numFields: Int) extends Serializer with S private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInstance { + /** + * Marks the end of a stream written with [[serializeStream()]]. + */ private[this] val EOF: Int = -1 + /** + * Serializes a stream of UnsafeRows. Within the stream, each record consists of a record + * length (stored as a 4-byte integer, written high byte first), followed by the record's bytes. + * The end of the stream is denoted by a record with the special length `EOF` (-1). + */ override def serializeStream(out: OutputStream): SerializationStream = new SerializationStream { private[this] var writeBuffer: Array[Byte] = new Array[Byte](4096) private[this] val dOut: DataOutputStream = new DataOutputStream(out) @@ -59,32 +67,31 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst val row = value.asInstanceOf[UnsafeRow] assert(row.getPool == null, "UnsafeRowSerializer does not support ObjectPool") dOut.writeInt(row.getSizeInBytes) - var dataRemaining: Int = row.getSizeInBytes - val baseObject = row.getBaseObject - var rowReadPosition: Long = row.getBaseOffset - while (dataRemaining > 0) { - val toTransfer: Int = Math.min(writeBuffer.length, dataRemaining) - PlatformDependent.copyMemory( - baseObject, - rowReadPosition, - writeBuffer, - PlatformDependent.BYTE_ARRAY_OFFSET, - toTransfer) - out.write(writeBuffer, 0, toTransfer) - rowReadPosition += toTransfer - dataRemaining -= toTransfer - } + row.writeToStream(out, writeBuffer) this } + override def writeKey[T: ClassTag](key: T): SerializationStream = { + // The key is only needed on the map side when computing partition ids. It does not need to + // be shuffled. assert(key.isInstanceOf[Int]) this } - override def writeAll[T: ClassTag](iter: Iterator[T]): SerializationStream = + + override def writeAll[T: ClassTag](iter: Iterator[T]): SerializationStream = { + // This method is never called by shuffle code. throw new UnsupportedOperationException - override def writeObject[T: ClassTag](t: T): SerializationStream = + } + + override def writeObject[T: ClassTag](t: T): SerializationStream = { + // This method is never called by shuffle code. 
throw new UnsupportedOperationException - override def flush(): Unit = dOut.flush() + } + + override def flush(): Unit = { + dOut.flush() + } + override def close(): Unit = { writeBuffer = null dOut.writeInt(EOF) @@ -95,6 +102,7 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst override def deserializeStream(in: InputStream): DeserializationStream = { new DeserializationStream { private[this] val dIn: DataInputStream = new DataInputStream(in) + // 1024 is a default buffer size; this buffer will grow to accommodate larger rows private[this] var rowBuffer: Array[Byte] = new Array[Byte](1024) private[this] var row: UnsafeRow = new UnsafeRow() private[this] var rowTuple: (Int, UnsafeRow) = (0, row) @@ -126,14 +134,40 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst } } } - override def asIterator: Iterator[Any] = throw new UnsupportedOperationException - override def readKey[T: ClassTag](): T = throw new UnsupportedOperationException - override def readValue[T: ClassTag](): T = throw new UnsupportedOperationException - override def readObject[T: ClassTag](): T = throw new UnsupportedOperationException - override def close(): Unit = dIn.close() + + override def asIterator: Iterator[Any] = { + // This method is never called by shuffle code. + throw new UnsupportedOperationException + } + + override def readKey[T: ClassTag](): T = { + // We skipped serialization of the key in writeKey(), so just return a dummy value since + // this is going to be discarded anyways. + null.asInstanceOf[T] + } + + override def readValue[T: ClassTag](): T = { + val rowSize = dIn.readInt() + if (rowBuffer.length < rowSize) { + rowBuffer = new Array[Byte](rowSize) + } + ByteStreams.readFully(in, rowBuffer, 0, rowSize) + row.pointTo(rowBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, numFields, rowSize, null) + row.asInstanceOf[T] + } + + override def readObject[T: ClassTag](): T = { + // This method is never called by shuffle code. + throw new UnsupportedOperationException + } + + override def close(): Unit = { + dIn.close() + } } } + // These methods are never called by shuffle code. override def serialize[T: ClassTag](t: T): ByteBuffer = throw new UnsupportedOperationException override def deserialize[T: ClassTag](bytes: ByteBuffer): T = throw new UnsupportedOperationException diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala new file mode 100644 index 0000000000000..3854dc1b7a3d1 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql + +import java.io.ByteArrayOutputStream + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, UnsafeProjection} +import org.apache.spark.sql.types.{IntegerType, StringType} +import org.apache.spark.unsafe.PlatformDependent +import org.apache.spark.unsafe.memory.MemoryAllocator +import org.apache.spark.unsafe.types.UTF8String + +class UnsafeRowSuite extends SparkFunSuite { + test("writeToStream") { + val row = InternalRow.apply(UTF8String.fromString("hello"), UTF8String.fromString("world"), 123) + val arrayBackedUnsafeRow: UnsafeRow = + UnsafeProjection.create(Seq(StringType, StringType, IntegerType)).apply(row) + assert(arrayBackedUnsafeRow.getBaseObject.isInstanceOf[Array[Byte]]) + val bytesFromArrayBackedRow: Array[Byte] = { + val baos = new ByteArrayOutputStream() + arrayBackedUnsafeRow.writeToStream(baos, null) + baos.toByteArray + } + val bytesFromOffheapRow: Array[Byte] = { + val offheapRowPage = MemoryAllocator.UNSAFE.allocate(arrayBackedUnsafeRow.getSizeInBytes) + try { + PlatformDependent.copyMemory( + arrayBackedUnsafeRow.getBaseObject, + arrayBackedUnsafeRow.getBaseOffset, + offheapRowPage.getBaseObject, + offheapRowPage.getBaseOffset, + arrayBackedUnsafeRow.getSizeInBytes + ) + val offheapUnsafeRow: UnsafeRow = new UnsafeRow() + offheapUnsafeRow.pointTo( + offheapRowPage.getBaseObject, + offheapRowPage.getBaseOffset, + 3, // num fields + arrayBackedUnsafeRow.getSizeInBytes, + null // object pool + ) + assert(offheapUnsafeRow.getBaseObject === null) + val baos = new ByteArrayOutputStream() + val writeBuffer = new Array[Byte](1024) + offheapUnsafeRow.writeToStream(baos, writeBuffer) + baos.toByteArray + } finally { + MemoryAllocator.UNSAFE.free(offheapRowPage) + } + } + + assert(bytesFromArrayBackedRow === bytesFromOffheapRow) + } +} From 228ab65a4eeef8a42eb4713edf72b50590f63176 Mon Sep 17 00:00:00 2001 From: Shivaram Venkataraman Date: Mon, 20 Jul 2015 23:31:08 -0700 Subject: [PATCH 14/32] [SPARK-9179] [BUILD] Use default primary author if unspecified Fixes feature introduced in #7508 to use the default value if nothing is specified in command line cc liancheng rxin pwendell Author: Shivaram Venkataraman Closes #7558 from shivaram/merge-script-fix and squashes the following commits: 7092141 [Shivaram Venkataraman] Use default primary author if unspecified --- dev/merge_spark_pr.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/dev/merge_spark_pr.py b/dev/merge_spark_pr.py index d586a57481aa1..ad4b76695c9ff 100755 --- a/dev/merge_spark_pr.py +++ b/dev/merge_spark_pr.py @@ -133,6 +133,8 @@ def merge_pr(pr_num, target_ref, title, body, pr_repo_desc): primary_author = raw_input( "Enter primary author in the format of \"name \" [%s]: " % distinct_authors[0]) + if primary_author == "": + primary_author = distinct_authors[0] commits = run_cmd(['git', 'log', 'HEAD..%s' % pr_branch_name, '--pretty=format:%h [%an] %s']).split("\n\n") From 1ddd0f2f1688560f88470e312b72af04364e2d49 Mon Sep 17 00:00:00 2001 From: Tarek Auel Date: Mon, 20 Jul 2015 23:33:07 -0700 Subject: [PATCH 15/32] [SPARK-9161][SQL] codegen FormatNumber Jira https://issues.apache.org/jira/browse/SPARK-9161 Author: Tarek Auel Closes #7545 from tarekauel/SPARK-9161 and squashes the following commits: 21425c8 [Tarek Auel] [SPARK-9161][SQL] codegen FormatNumber --- .../expressions/stringOperations.scala | 68 +++++++++++++++---- 1 file changed, 54 insertions(+), 14 deletions(-) diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 438215e8e6e37..92fefe1585b23 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -902,22 +902,15 @@ case class FormatNumber(x: Expression, d: Expression) @transient private val numberFormat: DecimalFormat = new DecimalFormat("") - override def eval(input: InternalRow): Any = { - val xObject = x.eval(input) - if (xObject == null) { + override protected def nullSafeEval(xObject: Any, dObject: Any): Any = { + val dValue = dObject.asInstanceOf[Int] + if (dValue < 0) { return null } - val dObject = d.eval(input) - - if (dObject == null || dObject.asInstanceOf[Int] < 0) { - return null - } - val dValue = dObject.asInstanceOf[Int] - if (dValue != lastDValue) { // construct a new DecimalFormat only if a new dValue - pattern.delete(0, pattern.length()) + pattern.delete(0, pattern.length) pattern.append("#,###,###,###,###,###,##0") // decimal place @@ -930,9 +923,10 @@ case class FormatNumber(x: Expression, d: Expression) pattern.append("0") } } - val dFormat = new DecimalFormat(pattern.toString()) - lastDValue = dValue; - numberFormat.applyPattern(dFormat.toPattern()) + val dFormat = new DecimalFormat(pattern.toString) + lastDValue = dValue + + numberFormat.applyPattern(dFormat.toPattern) } x.dataType match { @@ -947,6 +941,52 @@ case class FormatNumber(x: Expression, d: Expression) } } + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + nullSafeCodeGen(ctx, ev, (num, d) => { + + def typeHelper(p: String): String = { + x.dataType match { + case _ : DecimalType => s"""$p.toJavaBigDecimal()""" + case _ => s"$p" + } + } + + val sb = classOf[StringBuffer].getName + val df = classOf[DecimalFormat].getName + val lastDValue = ctx.freshName("lastDValue") + val pattern = ctx.freshName("pattern") + val numberFormat = ctx.freshName("numberFormat") + val i = ctx.freshName("i") + val dFormat = ctx.freshName("dFormat") + ctx.addMutableState("int", lastDValue, s"$lastDValue = -100;") + ctx.addMutableState(sb, pattern, s"$pattern = new $sb();") + ctx.addMutableState(df, numberFormat, s"""$numberFormat = new $df("");""") + + s""" + if ($d >= 0) { + $pattern.delete(0, $pattern.length()); + if ($d != $lastDValue) { + $pattern.append("#,###,###,###,###,###,##0"); + + if ($d > 0) { + $pattern.append("."); + for (int $i = 0; $i < $d; $i++) { + $pattern.append("0"); + } + } + $df $dFormat = new $df($pattern.toString()); + $lastDValue = $d; + $numberFormat.applyPattern($dFormat.toPattern()); + ${ev.primitive} = UTF8String.fromString($numberFormat.format(${typeHelper(num)})); + } + } else { + ${ev.primitive} = null; + ${ev.isNull} = true; + } + """ + }) + } + override def prettyName: String = "format_number" } From d38c5029a2ca845e2782096044a6412b653c9f95 Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Tue, 21 Jul 2015 15:08:44 +0800 Subject: [PATCH 16/32] [SPARK-9100] [SQL] Adds DataFrame reader/writer shortcut methods for ORC This PR adds DataFrame reader/writer shortcut methods for ORC in both Scala and Python. 
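For reviewers, a minimal usage sketch of the new shortcuts (a rough example, not part of the patch: the SparkContext `sc`, the sample data, and the output path are placeholders, and ORC support currently requires a HiveContext, per the notes in the diff):

    // Scala: write a DataFrame as ORC and read it back with the new shortcut methods,
    // equivalent to .format("orc").save(...) and .format("orc").load(...).
    import org.apache.spark.sql.hive.HiveContext

    val hiveContext = new HiveContext(sc)
    import hiveContext.implicits._

    val df = Seq((1, "a"), (2, "b")).toDF("key", "value")
    df.write.mode("overwrite").orc("/tmp/orc_example")
    val loaded = hiveContext.read.orc("/tmp/orc_example")
    loaded.show()

The Python shortcuts mirror this shape (`hiveContext.read.orc(path)` / `df.write.orc(path)`), as the doctests below show.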
Author: Cheng Lian Closes #7444 from liancheng/spark-9100 and squashes the following commits: 284d043 [Cheng Lian] Fixes PySpark test cases and addresses PR comments e0b09fb [Cheng Lian] Adds DataFrame reader/writer shortcut methods for ORC --- python/pyspark/sql/readwriter.py | 44 ++++++++++++++++-- .../sql/orc_partitioned/._SUCCESS.crc | Bin 0 -> 8 bytes .../test_support/sql/orc_partitioned/_SUCCESS | 0 ...9af031-b970-49d6-ad39-30460a0be2c8.orc.crc | Bin 0 -> 12 bytes ...0-829af031-b970-49d6-ad39-30460a0be2c8.orc | Bin 0 -> 168 bytes ...9af031-b970-49d6-ad39-30460a0be2c8.orc.crc | Bin 0 -> 12 bytes ...0-829af031-b970-49d6-ad39-30460a0be2c8.orc | Bin 0 -> 168 bytes .../apache/spark/sql/DataFrameReader.scala | 9 ++++ .../apache/spark/sql/DataFrameWriter.scala | 12 +++++ .../hive/orc/OrcHadoopFsRelationSuite.scala | 3 +- .../hive/orc/OrcPartitionDiscoverySuite.scala | 14 +++--- .../spark/sql/hive/orc/OrcQuerySuite.scala | 12 ++--- .../apache/spark/sql/hive/orc/OrcTest.scala | 8 ++-- 13 files changed, 79 insertions(+), 23 deletions(-) create mode 100644 python/test_support/sql/orc_partitioned/._SUCCESS.crc create mode 100755 python/test_support/sql/orc_partitioned/_SUCCESS create mode 100644 python/test_support/sql/orc_partitioned/b=0/c=0/.part-r-00000-829af031-b970-49d6-ad39-30460a0be2c8.orc.crc create mode 100755 python/test_support/sql/orc_partitioned/b=0/c=0/part-r-00000-829af031-b970-49d6-ad39-30460a0be2c8.orc create mode 100644 python/test_support/sql/orc_partitioned/b=1/c=1/.part-r-00000-829af031-b970-49d6-ad39-30460a0be2c8.orc.crc create mode 100755 python/test_support/sql/orc_partitioned/b=1/c=1/part-r-00000-829af031-b970-49d6-ad39-30460a0be2c8.orc diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 882a03090ec13..dea8bad79e187 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -146,14 +146,28 @@ def table(self, tableName): return self._df(self._jreader.table(tableName)) @since(1.4) - def parquet(self, *path): + def parquet(self, *paths): """Loads a Parquet file, returning the result as a :class:`DataFrame`. >>> df = sqlContext.read.parquet('python/test_support/sql/parquet_partitioned') >>> df.dtypes [('name', 'string'), ('year', 'int'), ('month', 'int'), ('day', 'int')] """ - return self._df(self._jreader.parquet(_to_seq(self._sqlContext._sc, path))) + return self._df(self._jreader.parquet(_to_seq(self._sqlContext._sc, paths))) + + @since(1.5) + def orc(self, path): + """ + Loads an ORC file, returning the result as a :class:`DataFrame`. + + ::Note: Currently ORC support is only available together with + :class:`HiveContext`. + + >>> df = hiveContext.read.orc('python/test_support/sql/orc_partitioned') + >>> df.dtypes + [('a', 'bigint'), ('b', 'int'), ('c', 'int')] + """ + return self._df(self._jreader.orc(path)) @since(1.4) def jdbc(self, url, table, column=None, lowerBound=None, upperBound=None, numPartitions=None, @@ -378,6 +392,29 @@ def parquet(self, path, mode=None, partitionBy=None): self.partitionBy(partitionBy) self._jwrite.parquet(path) + def orc(self, path, mode=None, partitionBy=None): + """Saves the content of the :class:`DataFrame` in ORC format at the specified path. + + ::Note: Currently ORC support is only available together with + :class:`HiveContext`. + + :param path: the path in any Hadoop supported file system + :param mode: specifies the behavior of the save operation when data already exists. + + * ``append``: Append contents of this :class:`DataFrame` to existing data. 
+ * ``overwrite``: Overwrite existing data. + * ``ignore``: Silently ignore this operation if data already exists. + * ``error`` (default case): Throw an exception if data already exists. + :param partitionBy: names of partitioning columns + + >>> orc_df = hiveContext.read.orc('python/test_support/sql/orc_partitioned') + >>> orc_df.write.orc(os.path.join(tempfile.mkdtemp(), 'data')) + """ + self.mode(mode) + if partitionBy is not None: + self.partitionBy(partitionBy) + self._jwrite.orc(path) + @since(1.4) def jdbc(self, url, table, mode=None, properties={}): """Saves the content of the :class:`DataFrame` to a external database table via JDBC. @@ -408,7 +445,7 @@ def _test(): import os import tempfile from pyspark.context import SparkContext - from pyspark.sql import Row, SQLContext + from pyspark.sql import Row, SQLContext, HiveContext import pyspark.sql.readwriter os.chdir(os.environ["SPARK_HOME"]) @@ -420,6 +457,7 @@ def _test(): globs['os'] = os globs['sc'] = sc globs['sqlContext'] = SQLContext(sc) + globs['hiveContext'] = HiveContext(sc) globs['df'] = globs['sqlContext'].read.parquet('python/test_support/sql/parquet_partitioned') (failure_count, test_count) = doctest.testmod( diff --git a/python/test_support/sql/orc_partitioned/._SUCCESS.crc b/python/test_support/sql/orc_partitioned/._SUCCESS.crc new file mode 100644 index 0000000000000000000000000000000000000000..3b7b044936a890cd8d651d349a752d819d71d22c GIT binary patch literal 8 PcmYc;N@ieSU}69O2$TUk literal 0 HcmV?d00001 diff --git a/python/test_support/sql/orc_partitioned/_SUCCESS b/python/test_support/sql/orc_partitioned/_SUCCESS new file mode 100755 index 0000000000000..e69de29bb2d1d diff --git a/python/test_support/sql/orc_partitioned/b=0/c=0/.part-r-00000-829af031-b970-49d6-ad39-30460a0be2c8.orc.crc b/python/test_support/sql/orc_partitioned/b=0/c=0/.part-r-00000-829af031-b970-49d6-ad39-30460a0be2c8.orc.crc new file mode 100644 index 0000000000000000000000000000000000000000..834cf0b7f227244a3ccda18809a0bb49d27b59d2 GIT binary patch literal 12 TcmYc;N@ieSU}CV!x?uzW5r_im literal 0 HcmV?d00001 diff --git a/python/test_support/sql/orc_partitioned/b=0/c=0/part-r-00000-829af031-b970-49d6-ad39-30460a0be2c8.orc b/python/test_support/sql/orc_partitioned/b=0/c=0/part-r-00000-829af031-b970-49d6-ad39-30460a0be2c8.orc new file mode 100755 index 0000000000000000000000000000000000000000..49438018733565be297429b4f9349450441230f9 GIT binary patch literal 168 zcmeYda^_`V;9?PC;$Tzk9-$Hm^fGr7_ER>tdO)gOz`6{6JV5RXb@0hV&KsbZTiB@>>uPT3IK<`7W)7I literal 0 HcmV?d00001 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 0e37ad3e12e08..f1c1ddf898986 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -264,6 +264,15 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging { } } + /** + * Loads an ORC file and returns the result as a [[DataFrame]]. + * + * @param path input path + * @since 1.5.0 + * @note Currently, this method can only be used together with `HiveContext`. + */ + def orc(path: String): DataFrame = format("orc").load(path) + /** * Returns the specified table as a [[DataFrame]]. 
* diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 5548b26cb8f80..3e7b9cd7976c3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -280,6 +280,18 @@ final class DataFrameWriter private[sql](df: DataFrame) { */ def parquet(path: String): Unit = format("parquet").save(path) + /** + * Saves the content of the [[DataFrame]] in ORC format at the specified path. + * This is equivalent to: + * {{{ + * format("orc").save(path) + * }}} + * + * @since 1.5.0 + * @note Currently, this method can only be used together with `HiveContext`. + */ + def orc(path: String): Unit = format("orc").save(path) + /////////////////////////////////////////////////////////////////////////////////////// // Builder pattern config options /////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala index 080af5bb23c16..af3f468aaa5e9 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala @@ -41,8 +41,7 @@ class OrcHadoopFsRelationSuite extends HadoopFsRelationTest { .parallelize(for (i <- 1 to 3) yield (i, s"val_$i", p1)) .toDF("a", "b", "p1") .write - .format("orc") - .save(partitionDir.toString) + .orc(partitionDir.toString) } val dataSchemaWithPartition = diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala index 3c2efe329bfd5..d463e8fd626f9 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala @@ -49,13 +49,13 @@ class OrcPartitionDiscoverySuite extends QueryTest with BeforeAndAfterAll { def makeOrcFile[T <: Product: ClassTag: TypeTag]( data: Seq[T], path: File): Unit = { - data.toDF().write.format("orc").mode("overwrite").save(path.getCanonicalPath) + data.toDF().write.mode("overwrite").orc(path.getCanonicalPath) } def makeOrcFile[T <: Product: ClassTag: TypeTag]( df: DataFrame, path: File): Unit = { - df.write.format("orc").mode("overwrite").save(path.getCanonicalPath) + df.write.mode("overwrite").orc(path.getCanonicalPath) } protected def withTempTable(tableName: String)(f: => Unit): Unit = { @@ -90,7 +90,7 @@ class OrcPartitionDiscoverySuite extends QueryTest with BeforeAndAfterAll { makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps)) } - read.format("orc").load(base.getCanonicalPath).registerTempTable("t") + read.orc(base.getCanonicalPath).registerTempTable("t") withTempTable("t") { checkAnswer( @@ -137,7 +137,7 @@ class OrcPartitionDiscoverySuite extends QueryTest with BeforeAndAfterAll { makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps)) } - read.format("orc").load(base.getCanonicalPath).registerTempTable("t") + read.orc(base.getCanonicalPath).registerTempTable("t") withTempTable("t") { checkAnswer( @@ -187,9 +187,8 @@ class OrcPartitionDiscoverySuite extends QueryTest with BeforeAndAfterAll { } read - .format("orc") 
.option(ConfVars.DEFAULTPARTITIONNAME.varname, defaultPartitionName) - .load(base.getCanonicalPath) + .orc(base.getCanonicalPath) .registerTempTable("t") withTempTable("t") { @@ -230,9 +229,8 @@ class OrcPartitionDiscoverySuite extends QueryTest with BeforeAndAfterAll { } read - .format("orc") .option(ConfVars.DEFAULTPARTITIONNAME.varname, defaultPartitionName) - .load(base.getCanonicalPath) + .orc(base.getCanonicalPath) .registerTempTable("t") withTempTable("t") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala index ca131faaeef05..744d462938141 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala @@ -63,14 +63,14 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { withOrcFile(data) { file => checkAnswer( - sqlContext.read.format("orc").load(file), + sqlContext.read.orc(file), data.toDF().collect()) } } test("Read/write binary data") { withOrcFile(BinaryData("test".getBytes("utf8")) :: Nil) { file => - val bytes = read.format("orc").load(file).head().getAs[Array[Byte]](0) + val bytes = read.orc(file).head().getAs[Array[Byte]](0) assert(new String(bytes, "utf8") === "test") } } @@ -88,7 +88,7 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { withOrcFile(data) { file => checkAnswer( - read.format("orc").load(file), + read.orc(file), data.toDF().collect()) } } @@ -158,7 +158,7 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { withOrcFile(data) { file => checkAnswer( - read.format("orc").load(file), + read.orc(file), Row(Seq.fill(5)(null): _*)) } } @@ -310,7 +310,7 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { """.stripMargin) val errorMessage = intercept[AnalysisException] { - sqlContext.read.format("orc").load(path) + sqlContext.read.orc(path) }.getMessage assert(errorMessage.contains("Failed to discover schema from ORC files")) @@ -323,7 +323,7 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { |SELECT key, value FROM single """.stripMargin) - val df = sqlContext.read.format("orc").load(path) + val df = sqlContext.read.orc(path) assert(df.schema === singleRowDF.schema.asNullable) checkAnswer(df, singleRowDF) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala index 5daf691aa8c53..9d76d6503a3e6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala @@ -39,7 +39,7 @@ private[sql] trait OrcTest extends SQLTestUtils { (data: Seq[T]) (f: String => Unit): Unit = { withTempPath { file => - sparkContext.parallelize(data).toDF().write.format("orc").save(file.getCanonicalPath) + sparkContext.parallelize(data).toDF().write.orc(file.getCanonicalPath) f(file.getCanonicalPath) } } @@ -51,7 +51,7 @@ private[sql] trait OrcTest extends SQLTestUtils { protected def withOrcDataFrame[T <: Product: ClassTag: TypeTag] (data: Seq[T]) (f: DataFrame => Unit): Unit = { - withOrcFile(data)(path => f(sqlContext.read.format("orc").load(path))) + withOrcFile(data)(path => f(sqlContext.read.orc(path))) } /** @@ -70,11 +70,11 @@ private[sql] trait OrcTest extends SQLTestUtils { protected def makeOrcFile[T <: Product: ClassTag: TypeTag]( data: Seq[T], path: 
File): Unit = { - data.toDF().write.format("orc").mode(SaveMode.Overwrite).save(path.getCanonicalPath) + data.toDF().write.mode(SaveMode.Overwrite).orc(path.getCanonicalPath) } protected def makeOrcFile[T <: Product: ClassTag: TypeTag]( df: DataFrame, path: File): Unit = { - df.write.format("orc").mode(SaveMode.Overwrite).save(path.getCanonicalPath) + df.write.mode(SaveMode.Overwrite).orc(path.getCanonicalPath) } } From 8c8f0ef59e12b6f13d5a0bf2d7bf1248b5c1e369 Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Tue, 21 Jul 2015 00:48:07 -0700 Subject: [PATCH 17/32] [SPARK-8255] [SPARK-8256] [SQL] Add regex_extract/regex_replace Add expressions `regex_extract` & `regex_replace` Author: Cheng Hao Closes #7468 from chenghao-intel/regexp and squashes the following commits: e5ea476 [Cheng Hao] minor update for documentation ef96fd6 [Cheng Hao] update the code gen 72cf28f [Cheng Hao] Add more log for compilation error 4e11381 [Cheng Hao] Add regexp_replace / regexp_extract support --- python/pyspark/sql/functions.py | 30 +++ .../catalyst/analysis/FunctionRegistry.scala | 2 + .../expressions/codegen/CodeGenerator.scala | 5 +- .../expressions/stringOperations.scala | 217 +++++++++++++++++- .../expressions/ExpressionEvalHelper.scala | 1 - .../expressions/StringExpressionsSuite.scala | 35 +++ .../org/apache/spark/sql/functions.scala | 21 ++ .../spark/sql/StringFunctionsSuite.scala | 16 ++ 8 files changed, 323 insertions(+), 4 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 031745a1c4d3b..3c134faa0a765 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -46,6 +46,8 @@ 'monotonicallyIncreasingId', 'rand', 'randn', + 'regexp_extract', + 'regexp_replace', 'sha1', 'sha2', 'sparkPartitionId', @@ -343,6 +345,34 @@ def levenshtein(left, right): return Column(jc) +@ignore_unicode_prefix +@since(1.5) +def regexp_extract(str, pattern, idx): + """Extract a specific(idx) group identified by a java regex, from the specified string column. + + >>> df = sqlContext.createDataFrame([('100-200',)], ['str']) + >>> df.select(regexp_extract('str', '(\d+)-(\d+)', 1).alias('d')).collect() + [Row(d=u'100')] + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.regexp_extract(_to_java_column(str), pattern, idx) + return Column(jc) + + +@ignore_unicode_prefix +@since(1.5) +def regexp_replace(str, pattern, replacement): + """Replace all substrings of the specified string value that match regexp with rep. 
+ + >>> df = sqlContext.createDataFrame([('100-200',)], ['str']) + >>> df.select(regexp_replace('str', '(\\d+)', '##').alias('d')).collect() + [Row(d=u'##-##')] + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.regexp_replace(_to_java_column(str), pattern, replacement) + return Column(jc) + + @ignore_unicode_prefix @since(1.5) def md5(col): diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 71e87b98d86fc..aec392379c186 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -161,6 +161,8 @@ object FunctionRegistry { expression[Lower]("lower"), expression[Length]("length"), expression[Levenshtein]("levenshtein"), + expression[RegExpExtract]("regexp_extract"), + expression[RegExpReplace]("regexp_replace"), expression[StringInstr]("instr"), expression[StringLocate]("locate"), expression[StringLPad]("lpad"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 606f770cb4f7b..319dcd1c04316 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -297,8 +297,9 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin evaluator.cook(code) } catch { case e: Exception => - logError(s"failed to compile:\n $code", e) - throw e + val msg = s"failed to compile:\n $code" + logError(msg, e) + throw new Exception(msg, e) } evaluator.getClazz().newInstance().asInstanceOf[GeneratedClass] } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 92fefe1585b23..fe57d17f1ec14 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import java.text.DecimalFormat import java.util.Locale -import java.util.regex.Pattern +import java.util.regex.{MatchResult, Pattern} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.UnresolvedException @@ -876,6 +876,221 @@ case class Encode(value: Expression, charset: Expression) } } +/** + * Replace all substrings of str that match regexp with rep. + * + * NOTE: this expression is not THREAD-SAFE, as it has some internal mutable status. + */ +case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expression) + extends Expression with ImplicitCastInputTypes { + + // last regex in string, we will update the pattern iff regexp value changed. + @transient private var lastRegex: UTF8String = _ + // last regex pattern, we cache it for performance concern + @transient private var pattern: Pattern = _ + // last replacement string, we don't want to convert a UTF8String => java.langString every time. 
+ @transient private var lastReplacement: String = _ + @transient private var lastReplacementInUTF8: UTF8String = _ + // result buffer write by Matcher + @transient private val result: StringBuffer = new StringBuffer + + override def nullable: Boolean = subject.nullable || regexp.nullable || rep.nullable + override def foldable: Boolean = subject.foldable && regexp.foldable && rep.foldable + + override def eval(input: InternalRow): Any = { + val s = subject.eval(input) + if (null != s) { + val p = regexp.eval(input) + if (null != p) { + val r = rep.eval(input) + if (null != r) { + if (!p.equals(lastRegex)) { + // regex value changed + lastRegex = p.asInstanceOf[UTF8String] + pattern = Pattern.compile(lastRegex.toString) + } + if (!r.equals(lastReplacementInUTF8)) { + // replacement string changed + lastReplacementInUTF8 = r.asInstanceOf[UTF8String] + lastReplacement = lastReplacementInUTF8.toString + } + val m = pattern.matcher(s.toString()) + result.delete(0, result.length()) + + while (m.find) { + m.appendReplacement(result, lastReplacement) + } + m.appendTail(result) + + return UTF8String.fromString(result.toString) + } + } + } + + null + } + + override def dataType: DataType = StringType + override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType, StringType) + override def children: Seq[Expression] = subject :: regexp :: rep :: Nil + override def prettyName: String = "regexp_replace" + + override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val termLastRegex = ctx.freshName("lastRegex") + val termPattern = ctx.freshName("pattern") + + val termLastReplacement = ctx.freshName("lastReplacement") + val termLastReplacementInUTF8 = ctx.freshName("lastReplacementInUTF8") + + val termResult = ctx.freshName("result") + + val classNameUTF8String = classOf[UTF8String].getCanonicalName + val classNamePattern = classOf[Pattern].getCanonicalName + val classNameString = classOf[java.lang.String].getCanonicalName + val classNameStringBuffer = classOf[java.lang.StringBuffer].getCanonicalName + + ctx.addMutableState(classNameUTF8String, + termLastRegex, s"${termLastRegex} = null;") + ctx.addMutableState(classNamePattern, + termPattern, s"${termPattern} = null;") + ctx.addMutableState(classNameString, + termLastReplacement, s"${termLastReplacement} = null;") + ctx.addMutableState(classNameUTF8String, + termLastReplacementInUTF8, s"${termLastReplacementInUTF8} = null;") + ctx.addMutableState(classNameStringBuffer, + termResult, s"${termResult} = new $classNameStringBuffer();") + + val evalSubject = subject.gen(ctx) + val evalRegexp = regexp.gen(ctx) + val evalRep = rep.gen(ctx) + + s""" + ${evalSubject.code} + boolean ${ev.isNull} = ${evalSubject.isNull}; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${evalSubject.isNull}) { + ${evalRegexp.code} + if (!${evalRegexp.isNull}) { + ${evalRep.code} + if (!${evalRep.isNull}) { + if (!${evalRegexp.primitive}.equals(${termLastRegex})) { + // regex value changed + ${termLastRegex} = ${evalRegexp.primitive}; + ${termPattern} = ${classNamePattern}.compile(${termLastRegex}.toString()); + } + if (!${evalRep.primitive}.equals(${termLastReplacementInUTF8})) { + // replacement string changed + ${termLastReplacementInUTF8} = ${evalRep.primitive}; + ${termLastReplacement} = ${termLastReplacementInUTF8}.toString(); + } + ${termResult}.delete(0, ${termResult}.length()); + ${classOf[java.util.regex.Matcher].getCanonicalName} m = + 
${termPattern}.matcher(${evalSubject.primitive}.toString()); + + while (m.find()) { + m.appendReplacement(${termResult}, ${termLastReplacement}); + } + m.appendTail(${termResult}); + ${ev.primitive} = ${classNameUTF8String}.fromString(${termResult}.toString()); + ${ev.isNull} = false; + } + } + } + """ + } +} + +/** + * Extract a specific(idx) group identified by a Java regex. + * + * NOTE: this expression is not THREAD-SAFE, as it has some internal mutable status. + */ +case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expression) + extends Expression with ImplicitCastInputTypes { + def this(s: Expression, r: Expression) = this(s, r, Literal(1)) + + // last regex in string, we will update the pattern iff regexp value changed. + @transient private var lastRegex: UTF8String = _ + // last regex pattern, we cache it for performance concern + @transient private var pattern: Pattern = _ + + override def nullable: Boolean = subject.nullable || regexp.nullable || idx.nullable + override def foldable: Boolean = subject.foldable && regexp.foldable && idx.foldable + + override def eval(input: InternalRow): Any = { + val s = subject.eval(input) + if (null != s) { + val p = regexp.eval(input) + if (null != p) { + val r = idx.eval(input) + if (null != r) { + if (!p.equals(lastRegex)) { + // regex value changed + lastRegex = p.asInstanceOf[UTF8String] + pattern = Pattern.compile(lastRegex.toString) + } + val m = pattern.matcher(s.toString()) + if (m.find) { + val mr: MatchResult = m.toMatchResult + return UTF8String.fromString(mr.group(r.asInstanceOf[Int])) + } + return UTF8String.EMPTY_UTF8 + } + } + } + + null + } + + override def dataType: DataType = StringType + override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType, IntegerType) + override def children: Seq[Expression] = subject :: regexp :: idx :: Nil + override def prettyName: String = "regexp_extract" + + override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val termLastRegex = ctx.freshName("lastRegex") + val termPattern = ctx.freshName("pattern") + val classNameUTF8String = classOf[UTF8String].getCanonicalName + val classNamePattern = classOf[Pattern].getCanonicalName + + ctx.addMutableState(classNameUTF8String, termLastRegex, s"${termLastRegex} = null;") + ctx.addMutableState(classNamePattern, termPattern, s"${termPattern} = null;") + + val evalSubject = subject.gen(ctx) + val evalRegexp = regexp.gen(ctx) + val evalIdx = idx.gen(ctx) + + s""" + ${ctx.javaType(dataType)} ${ev.primitive} = null; + boolean ${ev.isNull} = true; + ${evalSubject.code} + if (!${evalSubject.isNull}) { + ${evalRegexp.code} + if (!${evalRegexp.isNull}) { + ${evalIdx.code} + if (!${evalIdx.isNull}) { + if (!${evalRegexp.primitive}.equals(${termLastRegex})) { + // regex value changed + ${termLastRegex} = ${evalRegexp.primitive}; + ${termPattern} = ${classNamePattern}.compile(${termLastRegex}.toString()); + } + ${classOf[java.util.regex.Matcher].getCanonicalName} m = + ${termPattern}.matcher(${evalSubject.primitive}.toString()); + if (m.find()) { + ${classOf[java.util.regex.MatchResult].getCanonicalName} mr = m.toMatchResult(); + ${ev.primitive} = ${classNameUTF8String}.fromString(mr.group(${evalIdx.primitive})); + ${ev.isNull} = false; + } else { + ${ev.primitive} = ${classNameUTF8String}.EMPTY_UTF8; + ${ev.isNull} = false; + } + } + } + } + """ + } +} + /** * Formats the number X to a format like '#,###,###.##', rounded to D decimal places, * and returns the result as a string. 
If D is 0, the result has no decimal point or diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index 7a96044d35a09..6e17ffcda9dc4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -79,7 +79,6 @@ trait ExpressionEvalHelper { fail( s""" |Code generation of $expression failed: - |${evaluated.code} |$e """.stripMargin) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index 67d97cd30b039..96c540ab36f08 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -464,6 +464,41 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(StringSpace(s1), null, row2) } + test("RegexReplace") { + val row1 = create_row("100-200", "(\\d+)", "num") + val row2 = create_row("100-200", "(\\d+)", "###") + val row3 = create_row("100-200", "(-)", "###") + + val s = 's.string.at(0) + val p = 'p.string.at(1) + val r = 'r.string.at(2) + + val expr = RegExpReplace(s, p, r) + checkEvaluation(expr, "num-num", row1) + checkEvaluation(expr, "###-###", row2) + checkEvaluation(expr, "100###200", row3) + } + + test("RegexExtract") { + val row1 = create_row("100-200", "(\\d+)-(\\d+)", 1) + val row2 = create_row("100-200", "(\\d+)-(\\d+)", 2) + val row3 = create_row("100-200", "(\\d+).*", 1) + val row4 = create_row("100-200", "([a-z])", 1) + + val s = 's.string.at(0) + val p = 'p.string.at(1) + val r = 'r.int.at(2) + + val expr = RegExpExtract(s, p, r) + checkEvaluation(expr, "100", row1) + checkEvaluation(expr, "200", row2) + checkEvaluation(expr, "100", row3) + checkEvaluation(expr, "", row4) // will not match anything, empty string get + + val expr1 = new RegExpExtract(s, p) + checkEvaluation(expr1, "100", row1) + } + test("SPLIT") { val s1 = 'a.string.at(0) val s2 = 'b.string.at(1) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 8fa017610b63c..6d60dae624b0c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1781,6 +1781,27 @@ object functions { StringLocate(lit(substr).expr, str.expr, lit(pos).expr) } + + /** + * Extract a specific(idx) group identified by a java regex, from the specified string column. + * + * @group string_funcs + * @since 1.5.0 + */ + def regexp_extract(e: Column, exp: String, groupIdx: Int): Column = { + RegExpExtract(e.expr, lit(exp).expr, lit(groupIdx).expr) + } + + /** + * Replace all substrings of the specified string value that match regexp with rep. + * + * @group string_funcs + * @since 1.5.0 + */ + def regexp_replace(e: Column, pattern: String, replacement: String): Column = { + RegExpReplace(e.expr, lit(pattern).expr, lit(replacement).expr) + } + /** * Computes the BASE64 encoding of a binary column and returns it as a string column. * This is the reverse of unbase64. 
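Reviewer note: the regexp_replace / regexp_extract wrappers added to functions.scala above are meant to be called from the DataFrame API roughly as follows. This is only a sketch mirroring the suite in the next hunk; the SQLContext, its implicits import, and the column name "a" are assumed for illustration and are not part of the patch.

    import sqlContext.implicits._
    import org.apache.spark.sql.functions.{regexp_extract, regexp_replace}

    // Same data as the test below: one row with "100-200" in column "a".
    val df = Seq(("100-200", "")).toDF("a", "b")
    df.select(
      regexp_replace($"a", "(\\d+)", "num"),       // yields "num-num"
      regexp_extract($"a", "(\\d+)-(\\d+)", 1))    // yields "100"
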
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala index 4551192b157ff..d1f855903ca4b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala @@ -56,6 +56,22 @@ class StringFunctionsSuite extends QueryTest { checkAnswer(df.selectExpr("levenshtein(l, r)"), Seq(Row(3), Row(1))) } + test("string regex_replace / regex_extract") { + val df = Seq(("100-200", "")).toDF("a", "b") + + checkAnswer( + df.select( + regexp_replace($"a", "(\\d+)", "num"), + regexp_extract($"a", "(\\d+)-(\\d+)", 1)), + Row("num-num", "100")) + + checkAnswer( + df.selectExpr( + "regexp_replace(a, '(\\d+)', 'num')", + "regexp_extract(a, '(\\d+)-(\\d+)', 2)"), + Row("num-num", "200")) + } + test("string ascii function") { val df = Seq(("abc", "")).toDF("a", "b") checkAnswer( From 560c658a7462844c698b5bda09a4cfb4094fd65b Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Tue, 21 Jul 2015 00:53:20 -0700 Subject: [PATCH 18/32] [SPARK-8230][SQL] Add array/map size method Pull Request for: https://issues.apache.org/jira/browse/SPARK-8230 Primary issue resolved is to implement array/map size for Spark SQL. Code is ready for review by a committer. Chen Hao is on the JIRA ticket, but I don't know his username on github, rxin is also on JIRA ticket. Things to review: 1. Where to put added functions namespace wise, they seem to be part of a few operations on collections which includes `sort_array` and `array_contains`. Hence the name given `collectionOperations.scala` and `_collection_functions` in python. 2. In Python code, should it be in a `1.5.0` function array or in a collections array? 3. Are there any missing methods on the `Size` case class? Looks like many of these functions have generated Java code, is that also needed in this case? 4. Something else? 
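For readers skimming the patch, a minimal Scala sketch of the proposed size API (illustrative only; it assumes a SQLContext with implicits imported and mirrors the DataFrame tests added further down in this patch):

    import sqlContext.implicits._
    import org.apache.spark.sql.functions.size

    // Two rows: a two-element array and an empty array.
    val df = Seq((Array(1, 2), "x"), (Array[Int](), "y")).toDF("a", "b")
    df.select(size($"a"))        // rows: 2, 0
    df.selectExpr("size(a)")     // same result via the SQL expression form
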
Author: Pedro Rodriguez Author: Pedro Rodriguez Closes #7462 from EntilZha/SPARK-8230 and squashes the following commits: 9a442ae [Pedro Rodriguez] fixed functions and sorted __all__ 9aea3bb [Pedro Rodriguez] removed imports from python docs 15d4bf1 [Pedro Rodriguez] Added null test case and changed to nullSafeCodeGen d88247c [Pedro Rodriguez] removed python code bd5f0e4 [Pedro Rodriguez] removed duplicate function from rebase/merge 59931b4 [Pedro Rodriguez] fixed compile bug instroduced when merging c187175 [Pedro Rodriguez] updated code to add size to __all__ directly and removed redundent pretty print 130839f [Pedro Rodriguez] fixed failing test aa9bade [Pedro Rodriguez] fix style e093473 [Pedro Rodriguez] updated python code with docs, switched classes/traits implemented, added (failing) expression tests 0449377 [Pedro Rodriguez] refactored code to use better abstract classes/traits and implementations 9a1a2ff [Pedro Rodriguez] added unit tests for map size 2bfbcb6 [Pedro Rodriguez] added unit test for size 20df2b4 [Pedro Rodriguez] Finished working version of size function and added it to python b503e75 [Pedro Rodriguez] First attempt at implementing size for maps and arrays 99a6a5c [Pedro Rodriguez] fixed failing test cac75ac [Pedro Rodriguez] fix style 933d843 [Pedro Rodriguez] updated python code with docs, switched classes/traits implemented, added (failing) expression tests 42bb7d4 [Pedro Rodriguez] refactored code to use better abstract classes/traits and implementations f9c3b8a [Pedro Rodriguez] added unit tests for map size 2515d9f [Pedro Rodriguez] added documentation 0e60541 [Pedro Rodriguez] added unit test for size acf9853 [Pedro Rodriguez] Finished working version of size function and added it to python 84a5d38 [Pedro Rodriguez] First attempt at implementing size for maps and arrays --- python/pyspark/sql/functions.py | 15 ++++++ .../catalyst/analysis/FunctionRegistry.scala | 4 +- .../expressions/collectionOperations.scala | 37 +++++++++++++++ .../CollectionFunctionsSuite.scala | 46 +++++++++++++++++++ .../org/apache/spark/sql/functions.scala | 20 ++++++++ .../spark/sql/DataFrameFunctionsSuite.scala | 31 +++++++++++++ 6 files changed, 152 insertions(+), 1 deletion(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 3c134faa0a765..719e623a1a11f 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -50,6 +50,7 @@ 'regexp_replace', 'sha1', 'sha2', + 'size', 'sparkPartitionId', 'struct', 'udf', @@ -825,6 +826,20 @@ def weekofyear(col): return Column(sc._jvm.functions.weekofyear(col)) +@since(1.5) +def size(col): + """ + Collection function: returns the length of the array or map stored in the column. 
+ :param col: name of column or expression + + >>> df = sqlContext.createDataFrame([([1, 2, 3],),([1],),([],)], ['data']) + >>> df.select(size(df.data)).collect() + [Row(size(data)=3), Row(size(data)=1), Row(size(data)=0)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.size(_to_java_column(col))) + + class UserDefinedFunction(object): """ User defined function in Python diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index aec392379c186..13523720daff0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -195,8 +195,10 @@ object FunctionRegistry { expression[Quarter]("quarter"), expression[Second]("second"), expression[WeekOfYear]("weekofyear"), - expression[Year]("year") + expression[Year]("year"), + // collection functions + expression[Size]("size") ) val builtin: FunctionRegistry = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala new file mode 100644 index 0000000000000..2d92dcf23a86e --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} +import org.apache.spark.sql.types._ + +/** + * Given an array or map, returns its size. 
+ */ +case class Size(child: Expression) extends UnaryExpression with ExpectsInputTypes { + override def dataType: DataType = IntegerType + override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(ArrayType, MapType)) + + override def nullSafeEval(value: Any): Int = child.dataType match { + case ArrayType(_, _) => value.asInstanceOf[Seq[Any]].size + case MapType(_, _, _) => value.asInstanceOf[Map[Any, Any]].size + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + nullSafeCodeGen(ctx, ev, c => s"${ev.primitive} = ($c).size();") + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala new file mode 100644 index 0000000000000..28c41b57169f9 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.types._ + + +class CollectionFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { + + test("Array and Map Size") { + val a0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType)) + val a1 = Literal.create(Seq[Integer](), ArrayType(IntegerType)) + val a2 = Literal.create(Seq(1, 2), ArrayType(IntegerType)) + + checkEvaluation(Size(a0), 3) + checkEvaluation(Size(a1), 0) + checkEvaluation(Size(a2), 2) + + val m0 = Literal.create(Map("a" -> "a", "b" -> "b"), MapType(StringType, StringType)) + val m1 = Literal.create(Map[String, String](), MapType(StringType, StringType)) + val m2 = Literal.create(Map("a" -> "a"), MapType(StringType, StringType)) + + checkEvaluation(Size(m0), 2) + checkEvaluation(Size(m1), 0) + checkEvaluation(Size(m2), 1) + + checkEvaluation(Literal.create(null, MapType(StringType, StringType)), null) + checkEvaluation(Literal.create(null, ArrayType(StringType)), null) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 6d60dae624b0c..60b089180c876 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -42,6 +42,7 @@ import org.apache.spark.util.Utils * @groupname misc_funcs Misc functions * @groupname window_funcs Window functions * @groupname string_funcs String functions + * @groupname collection_funcs Collection functions * @groupname Ungrouped Support functions for DataFrames. 
* @since 1.3.0 */ @@ -2053,6 +2054,25 @@ object functions { */ def weekofyear(columnName: String): Column = weekofyear(Column(columnName)) + ////////////////////////////////////////////////////////////////////////////////////////////// + // Collection functions + ////////////////////////////////////////////////////////////////////////////////////////////// + + /** + * Returns length of array or map + * @group collection_funcs + * @since 1.5.0 + */ + def size(columnName: String): Column = size(Column(columnName)) + + /** + * Returns length of array or map + * @group collection_funcs + * @since 1.5.0 + */ + def size(column: Column): Column = Size(column.expr) + + ////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 8d2ff2f9690d6..1baec5d37699d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -267,4 +267,35 @@ class DataFrameFunctionsSuite extends QueryTest { ) } + test("array size function") { + val df = Seq( + (Array[Int](1, 2), "x"), + (Array[Int](), "y"), + (Array[Int](1, 2, 3), "z") + ).toDF("a", "b") + checkAnswer( + df.select(size("a")), + Seq(Row(2), Row(0), Row(3)) + ) + checkAnswer( + df.selectExpr("size(a)"), + Seq(Row(2), Row(0), Row(3)) + ) + } + + test("map size function") { + val df = Seq( + (Map[Int, Int](1 -> 1, 2 -> 2), "x"), + (Map[Int, Int](), "y"), + (Map[Int, Int](1 -> 1, 2 -> 2, 3 -> 3), "z") + ).toDF("a", "b") + checkAnswer( + df.select(size("a")), + Seq(Row(2), Row(0), Row(3)) + ) + checkAnswer( + df.selectExpr("size(a)"), + Seq(Row(2), Row(0), Row(3)) + ) + } } From ae230596b866d8e369bd061256c4cc569dba430a Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Tue, 21 Jul 2015 00:56:57 -0700 Subject: [PATCH 19/32] [SPARK-9173][SQL]UnionPushDown should also support Intersect and Except JIRA: https://issues.apache.org/jira/browse/SPARK-9173 Author: Yijie Shen Closes #7540 from yjshen/union_pushdown and squashes the following commits: 278510a [Yijie Shen] rename UnionPushDown to SetOperationPushDown 91741c1 [Yijie Shen] Add UnionPushDown support for intersect and except --- .../sql/catalyst/optimizer/Optimizer.scala | 47 +++++++++-- .../optimizer/SetOperationPushDownSuite.scala | 82 +++++++++++++++++++ .../optimizer/UnionPushdownSuite.scala | 61 -------------- 3 files changed, 120 insertions(+), 70 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala delete mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnionPushdownSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 9c45b196245da..e42f0b9a247e3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -40,7 +40,7 @@ object DefaultOptimizer extends Optimizer { ReplaceDistinctWithAggregate) :: Batch("Operator Optimizations", FixedPoint(100), // Operator push down - UnionPushDown, + SetOperationPushDown, SamplePushDown, 
PushPredicateThroughJoin, PushPredicateThroughProject, @@ -84,23 +84,24 @@ object SamplePushDown extends Rule[LogicalPlan] { } /** - * Pushes operations to either side of a Union. + * Pushes operations to either side of a Union, Intersect or Except. */ -object UnionPushDown extends Rule[LogicalPlan] { +object SetOperationPushDown extends Rule[LogicalPlan] { /** * Maps Attributes from the left side to the corresponding Attribute on the right side. */ - private def buildRewrites(union: Union): AttributeMap[Attribute] = { - assert(union.left.output.size == union.right.output.size) + private def buildRewrites(bn: BinaryNode): AttributeMap[Attribute] = { + assert(bn.isInstanceOf[Union] || bn.isInstanceOf[Intersect] || bn.isInstanceOf[Except]) + assert(bn.left.output.size == bn.right.output.size) - AttributeMap(union.left.output.zip(union.right.output)) + AttributeMap(bn.left.output.zip(bn.right.output)) } /** - * Rewrites an expression so that it can be pushed to the right side of a Union operator. - * This method relies on the fact that the output attributes of a union are always equal - * to the left child's output. + * Rewrites an expression so that it can be pushed to the right side of a + * Union, Intersect or Except operator. This method relies on the fact that the output attributes + * of a union/intersect/except are always equal to the left child's output. */ private def pushToRight[A <: Expression](e: A, rewrites: AttributeMap[Attribute]) = { val result = e transform { @@ -126,6 +127,34 @@ object UnionPushDown extends Rule[LogicalPlan] { Union( Project(projectList, left), Project(projectList.map(pushToRight(_, rewrites)), right)) + + // Push down filter into intersect + case Filter(condition, i @ Intersect(left, right)) => + val rewrites = buildRewrites(i) + Intersect( + Filter(condition, left), + Filter(pushToRight(condition, rewrites), right)) + + // Push down projection into intersect + case Project(projectList, i @ Intersect(left, right)) => + val rewrites = buildRewrites(i) + Intersect( + Project(projectList, left), + Project(projectList.map(pushToRight(_, rewrites)), right)) + + // Push down filter into except + case Filter(condition, e @ Except(left, right)) => + val rewrites = buildRewrites(e) + Except( + Filter(condition, left), + Filter(pushToRight(condition, rewrites), right)) + + // Push down projection into except + case Project(projectList, e @ Except(left, right)) => + val rewrites = buildRewrites(e) + Except( + Project(projectList, left), + Project(projectList.map(pushToRight(_, rewrites)), right)) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala new file mode 100644 index 0000000000000..49c979bc7d72c --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.dsl.expressions._ + +class SetOperationPushDownSuite extends PlanTest { + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Subqueries", Once, + EliminateSubQueries) :: + Batch("Union Pushdown", Once, + SetOperationPushDown) :: Nil + } + + val testRelation = LocalRelation('a.int, 'b.int, 'c.int) + val testRelation2 = LocalRelation('d.int, 'e.int, 'f.int) + val testUnion = Union(testRelation, testRelation2) + val testIntersect = Intersect(testRelation, testRelation2) + val testExcept = Except(testRelation, testRelation2) + + test("union/intersect/except: filter to each side") { + val unionQuery = testUnion.where('a === 1) + val intersectQuery = testIntersect.where('b < 10) + val exceptQuery = testExcept.where('c >= 5) + + val unionOptimized = Optimize.execute(unionQuery.analyze) + val intersectOptimized = Optimize.execute(intersectQuery.analyze) + val exceptOptimized = Optimize.execute(exceptQuery.analyze) + + val unionCorrectAnswer = + Union(testRelation.where('a === 1), testRelation2.where('d === 1)).analyze + val intersectCorrectAnswer = + Intersect(testRelation.where('b < 10), testRelation2.where('e < 10)).analyze + val exceptCorrectAnswer = + Except(testRelation.where('c >= 5), testRelation2.where('f >= 5)).analyze + + comparePlans(unionOptimized, unionCorrectAnswer) + comparePlans(intersectOptimized, intersectCorrectAnswer) + comparePlans(exceptOptimized, exceptCorrectAnswer) + } + + test("union/intersect/except: project to each side") { + val unionQuery = testUnion.select('a) + val intersectQuery = testIntersect.select('b, 'c) + val exceptQuery = testExcept.select('a, 'b, 'c) + + val unionOptimized = Optimize.execute(unionQuery.analyze) + val intersectOptimized = Optimize.execute(intersectQuery.analyze) + val exceptOptimized = Optimize.execute(exceptQuery.analyze) + + val unionCorrectAnswer = + Union(testRelation.select('a), testRelation2.select('d)).analyze + val intersectCorrectAnswer = + Intersect(testRelation.select('b, 'c), testRelation2.select('e, 'f)).analyze + val exceptCorrectAnswer = + Except(testRelation.select('a, 'b, 'c), testRelation2.select('d, 'e, 'f)).analyze + + comparePlans(unionOptimized, unionCorrectAnswer) + comparePlans(intersectOptimized, intersectCorrectAnswer) + comparePlans(exceptOptimized, exceptCorrectAnswer) } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnionPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnionPushdownSuite.scala deleted file mode 100644 index ec379489a6d1e..0000000000000 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnionPushdownSuite.scala +++ /dev/null @@ -1,61 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) 
under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.optimizer - -import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries -import org.apache.spark.sql.catalyst.plans.PlanTest -import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.rules._ -import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.dsl.expressions._ - -class UnionPushDownSuite extends PlanTest { - object Optimize extends RuleExecutor[LogicalPlan] { - val batches = - Batch("Subqueries", Once, - EliminateSubQueries) :: - Batch("Union Pushdown", Once, - UnionPushDown) :: Nil - } - - val testRelation = LocalRelation('a.int, 'b.int, 'c.int) - val testRelation2 = LocalRelation('d.int, 'e.int, 'f.int) - val testUnion = Union(testRelation, testRelation2) - - test("union: filter to each side") { - val query = testUnion.where('a === 1) - - val optimized = Optimize.execute(query.analyze) - - val correctAnswer = - Union(testRelation.where('a === 1), testRelation2.where('d === 1)).analyze - - comparePlans(optimized, correctAnswer) - } - - test("union: project to each side") { - val query = testUnion.select('b) - - val optimized = Optimize.execute(query.analyze) - - val correctAnswer = - Union(testRelation.select('b), testRelation2.select('e)).analyze - - comparePlans(optimized, correctAnswer) - } -} From 6364735bcc67ecb0e9c7e5076d214ed88e927430 Mon Sep 17 00:00:00 2001 From: Kay Ousterhout Date: Tue, 21 Jul 2015 01:12:51 -0700 Subject: [PATCH 20/32] [SPARK-8875] Remove BlockStoreShuffleFetcher class The shuffle code has gotten increasingly difficult to read as it has evolved, and many classes have evolved significantly since they were originally created. The BlockStoreShuffleFetcher class now serves little purpose other than to make the code more difficult to read; this commit moves its functionality into the ShuffleBlockFetcherIterator class. cc massie JoshRosen (Josh, this PR also removes the Try you pointed out as being confusing / not necessarily useful in a previous comment). Matt, would be helpful to know whether this will interfere in any negative ways with your new shuffle PR (I took a look and it seems like this should still cleanly integrate with your parquet work, but want to double check). 
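At the core of this refactoring is a change in what the map-output tracker hands back to the shuffle reader. A sketch of the before/after shapes follows; the method names and types come from the diffs below, while the value shown is illustrative (it mirrors the updated MapOutputTrackerSuite assertions) and the val name is made up for the example.

    // Old API: a flat array of (location, size) pairs, one per map output.
    //   getServerStatuses(shuffleId, reduceId): Array[(BlockManagerId, Long)]
    //
    // New API: outputs grouped by block manager and already expressed as shuffle
    // block ids, so ShuffleBlockFetcherIterator can consume them directly.
    //   getMapSizesByExecutorId(shuffleId, reduceId): Seq[(BlockManagerId, Seq[(BlockId, Long)])]

    import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId}

    val statuses: Seq[(BlockManagerId, Seq[(BlockId, Long)])] =
      Seq((BlockManagerId("a", "hostA", 1000),
        Seq((ShuffleBlockId(10, 0, 0), 1000L))))   // (shuffleId, mapId, reduceId) -> size
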
Author: Kay Ousterhout Closes #7268 from kayousterhout/SPARK-8875 and squashes the following commits: 2b24a97 [Kay Ousterhout] Fixed DAGSchedulerSuite compile error 98a1831 [Kay Ousterhout] Merge remote-tracking branch 'upstream/master' into SPARK-8875 90f0e89 [Kay Ousterhout] Fixed broken test 14bfcbb [Kay Ousterhout] Last style fix bc69d2b [Kay Ousterhout] Style improvements based on Josh's code review ad3c8d1 [Kay Ousterhout] Better documentation for MapOutputTracker methods 0bc0e59 [Kay Ousterhout] [SPARK-8875] Remove BlockStoreShuffleFetcher class --- .../org/apache/spark/MapOutputTracker.scala | 62 ++++++++++---- .../hash/BlockStoreShuffleFetcher.scala | 85 ------------------- .../shuffle/hash/HashShuffleReader.scala | 19 +++-- .../storage/ShuffleBlockFetcherIterator.scala | 72 ++++++++++------ .../apache/spark/MapOutputTrackerSuite.scala | 28 +++--- .../scala/org/apache/spark/ShuffleSuite.scala | 12 +-- .../spark/scheduler/DAGSchedulerSuite.scala | 32 +++---- .../shuffle/hash/HashShuffleReaderSuite.scala | 14 +-- .../ShuffleBlockFetcherIteratorSuite.scala | 18 ++-- .../apache/spark/util/AkkaUtilsSuite.scala | 22 +++-- 10 files changed, 172 insertions(+), 192 deletions(-) delete mode 100644 core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 862ffe868f58f..92218832d256f 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -21,14 +21,14 @@ import java.io._ import java.util.concurrent.ConcurrentHashMap import java.util.zip.{GZIPInputStream, GZIPOutputStream} -import scala.collection.mutable.{HashMap, HashSet, Map} +import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map} import scala.collection.JavaConversions._ import scala.reflect.ClassTag import org.apache.spark.rpc.{RpcEndpointRef, RpcEnv, RpcCallContext, RpcEndpoint} import org.apache.spark.scheduler.MapStatus import org.apache.spark.shuffle.MetadataFetchFailedException -import org.apache.spark.storage.BlockManagerId +import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId} import org.apache.spark.util._ private[spark] sealed trait MapOutputTrackerMessage @@ -124,10 +124,18 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging } /** - * Called from executors to get the server URIs and output sizes of the map outputs of - * a given shuffle. + * Called from executors to get the server URIs and output sizes for each shuffle block that + * needs to be read from a given reduce task. + * + * @return A sequence of 2-item tuples, where the first item in the tuple is a BlockManagerId, + * and the second item is a sequence of (shuffle block id, shuffle block size) tuples + * describing the shuffle blocks that are stored at that block manager. 
*/ - def getServerStatuses(shuffleId: Int, reduceId: Int): Array[(BlockManagerId, Long)] = { + def getMapSizesByExecutorId(shuffleId: Int, reduceId: Int) + : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = { + logDebug(s"Fetching outputs for shuffle $shuffleId, reduce $reduceId") + val startTime = System.currentTimeMillis + val statuses = mapStatuses.get(shuffleId).orNull if (statuses == null) { logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them") @@ -167,6 +175,9 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging } } } + logDebug(s"Fetching map output location for shuffle $shuffleId, reduce $reduceId took " + + s"${System.currentTimeMillis - startTime} ms") + if (fetchedStatuses != null) { fetchedStatuses.synchronized { return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, fetchedStatuses) @@ -421,23 +432,38 @@ private[spark] object MapOutputTracker extends Logging { } } - // Convert an array of MapStatuses to locations and sizes for a given reduce ID. If - // any of the statuses is null (indicating a missing location due to a failed mapper), - // throw a FetchFailedException. + /** + * Converts an array of MapStatuses for a given reduce ID to a sequence that, for each block + * manager ID, lists the shuffle block ids and corresponding shuffle block sizes stored at that + * block manager. + * + * If any of the statuses is null (indicating a missing location due to a failed mapper), + * throws a FetchFailedException. + * + * @param shuffleId Identifier for the shuffle + * @param reduceId Identifier for the reduce task + * @param statuses List of map statuses, indexed by map ID. + * @return A sequence of 2-item tuples, where the first item in the tuple is a BlockManagerId, + * and the second item is a sequence of (shuffle block id, shuffle block size) tuples + * describing the shuffle blocks that are stored at that block manager. + */ private def convertMapStatuses( shuffleId: Int, reduceId: Int, - statuses: Array[MapStatus]): Array[(BlockManagerId, Long)] = { + statuses: Array[MapStatus]): Seq[(BlockManagerId, Seq[(BlockId, Long)])] = { assert (statuses != null) - statuses.map { - status => - if (status == null) { - logError("Missing an output location for shuffle " + shuffleId) - throw new MetadataFetchFailedException( - shuffleId, reduceId, "Missing an output location for shuffle " + shuffleId) - } else { - (status.location, status.getSizeForBlock(reduceId)) - } + val splitsByAddress = new HashMap[BlockManagerId, ArrayBuffer[(BlockId, Long)]] + for ((status, mapId) <- statuses.zipWithIndex) { + if (status == null) { + val errorMessage = s"Missing an output location for shuffle $shuffleId" + logError(errorMessage) + throw new MetadataFetchFailedException(shuffleId, reduceId, errorMessage) + } else { + splitsByAddress.getOrElseUpdate(status.location, ArrayBuffer()) += + ((ShuffleBlockId(shuffleId, mapId, reduceId), status.getSizeForBlock(reduceId))) + } } + + splitsByAddress.toSeq } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala deleted file mode 100644 index 9d8e7e9f03aea..0000000000000 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala +++ /dev/null @@ -1,85 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. 
See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.shuffle.hash - -import java.io.InputStream - -import scala.collection.mutable.{ArrayBuffer, HashMap} -import scala.util.{Failure, Success} - -import org.apache.spark._ -import org.apache.spark.shuffle.FetchFailedException -import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerId, ShuffleBlockFetcherIterator, - ShuffleBlockId} - -private[hash] object BlockStoreShuffleFetcher extends Logging { - def fetchBlockStreams( - shuffleId: Int, - reduceId: Int, - context: TaskContext, - blockManager: BlockManager, - mapOutputTracker: MapOutputTracker) - : Iterator[(BlockId, InputStream)] = - { - logDebug("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId)) - - val startTime = System.currentTimeMillis - val statuses = mapOutputTracker.getServerStatuses(shuffleId, reduceId) - logDebug("Fetching map output location for shuffle %d, reduce %d took %d ms".format( - shuffleId, reduceId, System.currentTimeMillis - startTime)) - - val splitsByAddress = new HashMap[BlockManagerId, ArrayBuffer[(Int, Long)]] - for (((address, size), index) <- statuses.zipWithIndex) { - splitsByAddress.getOrElseUpdate(address, ArrayBuffer()) += ((index, size)) - } - - val blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])] = splitsByAddress.toSeq.map { - case (address, splits) => - (address, splits.map(s => (ShuffleBlockId(shuffleId, s._1, reduceId), s._2))) - } - - val blockFetcherItr = new ShuffleBlockFetcherIterator( - context, - blockManager.shuffleClient, - blockManager, - blocksByAddress, - // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility - SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024) - - // Make sure that fetch failures are wrapped inside a FetchFailedException for the scheduler - blockFetcherItr.map { blockPair => - val blockId = blockPair._1 - val blockOption = blockPair._2 - blockOption match { - case Success(inputStream) => { - (blockId, inputStream) - } - case Failure(e) => { - blockId match { - case ShuffleBlockId(shufId, mapId, _) => - val address = statuses(mapId.toInt)._1 - throw new FetchFailedException(address, shufId.toInt, mapId.toInt, reduceId, e) - case _ => - throw new SparkException( - "Failed to get block " + blockId + ", which is not a shuffle block", e) - } - } - } - } - } -} diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala index d5c9880659dd3..de79fa56f017b 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala @@ -17,10 +17,10 @@ package org.apache.spark.shuffle.hash -import org.apache.spark.{InterruptibleIterator, MapOutputTracker, SparkEnv, 
TaskContext} +import org.apache.spark.{InterruptibleIterator, Logging, MapOutputTracker, SparkEnv, TaskContext} import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader} -import org.apache.spark.storage.BlockManager +import org.apache.spark.storage.{BlockManager, ShuffleBlockFetcherIterator} import org.apache.spark.util.CompletionIterator import org.apache.spark.util.collection.ExternalSorter @@ -31,8 +31,8 @@ private[spark] class HashShuffleReader[K, C]( context: TaskContext, blockManager: BlockManager = SparkEnv.get.blockManager, mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker) - extends ShuffleReader[K, C] -{ + extends ShuffleReader[K, C] with Logging { + require(endPartition == startPartition + 1, "Hash shuffle currently only supports fetching one partition") @@ -40,11 +40,16 @@ private[spark] class HashShuffleReader[K, C]( /** Read the combined key-values for this reduce task */ override def read(): Iterator[Product2[K, C]] = { - val blockStreams = BlockStoreShuffleFetcher.fetchBlockStreams( - handle.shuffleId, startPartition, context, blockManager, mapOutputTracker) + val blockFetcherItr = new ShuffleBlockFetcherIterator( + context, + blockManager.shuffleClient, + blockManager, + mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition), + // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility + SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024) // Wrap the streams for compression based on configuration - val wrappedStreams = blockStreams.map { case (blockId, inputStream) => + val wrappedStreams = blockFetcherItr.map { case (blockId, inputStream) => blockManager.wrapForCompression(blockId, inputStream) } diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index e49e39679e940..a759ceb96ec1e 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -21,18 +21,19 @@ import java.io.InputStream import java.util.concurrent.LinkedBlockingQueue import scala.collection.mutable.{ArrayBuffer, HashSet, Queue} -import scala.util.{Failure, Try} +import scala.util.control.NonFatal -import org.apache.spark.{Logging, TaskContext} +import org.apache.spark.{Logging, SparkException, TaskContext} import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient} +import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.util.Utils /** * An iterator that fetches multiple blocks. For local blocks, it fetches from the local block * manager. For remote blocks, it fetches them using the provided BlockTransferService. * - * This creates an iterator of (BlockID, Try[InputStream]) tuples so the caller can handle blocks + * This creates an iterator of (BlockID, InputStream) tuples so the caller can handle blocks * in a pipelined fashion as they are received. 
* * The implementation throttles the remote fetches to they don't exceed maxBytesInFlight to avoid @@ -53,7 +54,7 @@ final class ShuffleBlockFetcherIterator( blockManager: BlockManager, blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])], maxBytesInFlight: Long) - extends Iterator[(BlockId, Try[InputStream])] with Logging { + extends Iterator[(BlockId, InputStream)] with Logging { import ShuffleBlockFetcherIterator._ @@ -115,7 +116,7 @@ final class ShuffleBlockFetcherIterator( private[storage] def releaseCurrentResultBuffer(): Unit = { // Release the current buffer if necessary currentResult match { - case SuccessFetchResult(_, _, buf) => buf.release() + case SuccessFetchResult(_, _, _, buf) => buf.release() case _ => } currentResult = null @@ -132,7 +133,7 @@ final class ShuffleBlockFetcherIterator( while (iter.hasNext) { val result = iter.next() result match { - case SuccessFetchResult(_, _, buf) => buf.release() + case SuccessFetchResult(_, _, _, buf) => buf.release() case _ => } } @@ -157,7 +158,7 @@ final class ShuffleBlockFetcherIterator( // Increment the ref count because we need to pass this to a different thread. // This needs to be released after use. buf.retain() - results.put(new SuccessFetchResult(BlockId(blockId), sizeMap(blockId), buf)) + results.put(new SuccessFetchResult(BlockId(blockId), address, sizeMap(blockId), buf)) shuffleMetrics.incRemoteBytesRead(buf.size) shuffleMetrics.incRemoteBlocksFetched(1) } @@ -166,7 +167,7 @@ final class ShuffleBlockFetcherIterator( override def onBlockFetchFailure(blockId: String, e: Throwable): Unit = { logError(s"Failed to get block(s) from ${req.address.host}:${req.address.port}", e) - results.put(new FailureFetchResult(BlockId(blockId), e)) + results.put(new FailureFetchResult(BlockId(blockId), address, e)) } } ) @@ -238,12 +239,12 @@ final class ShuffleBlockFetcherIterator( shuffleMetrics.incLocalBlocksFetched(1) shuffleMetrics.incLocalBytesRead(buf.size) buf.retain() - results.put(new SuccessFetchResult(blockId, 0, buf)) + results.put(new SuccessFetchResult(blockId, blockManager.blockManagerId, 0, buf)) } catch { case e: Exception => // If we see an exception, stop immediately. logError(s"Error occurred while fetching local blocks", e) - results.put(new FailureFetchResult(blockId, e)) + results.put(new FailureFetchResult(blockId, blockManager.blockManagerId, e)) return } } @@ -275,12 +276,14 @@ final class ShuffleBlockFetcherIterator( override def hasNext: Boolean = numBlocksProcessed < numBlocksToFetch /** - * Fetches the next (BlockId, Try[InputStream]). If a task fails, the ManagedBuffers + * Fetches the next (BlockId, InputStream). If a task fails, the ManagedBuffers * underlying each InputStream will be freed by the cleanup() method registered with the * TaskCompletionListener. However, callers should close() these InputStreams * as soon as they are no longer needed, in order to release memory as early as possible. + * + * Throws a FetchFailedException if the next block could not be fetched. 
*/ - override def next(): (BlockId, Try[InputStream]) = { + override def next(): (BlockId, InputStream) = { numBlocksProcessed += 1 val startFetchWait = System.currentTimeMillis() currentResult = results.take() @@ -289,7 +292,7 @@ final class ShuffleBlockFetcherIterator( shuffleMetrics.incFetchWaitTime(stopFetchWait - startFetchWait) result match { - case SuccessFetchResult(_, size, _) => bytesInFlight -= size + case SuccessFetchResult(_, _, size, _) => bytesInFlight -= size case _ => } // Send fetch requests up to maxBytesInFlight @@ -298,19 +301,28 @@ final class ShuffleBlockFetcherIterator( sendRequest(fetchRequests.dequeue()) } - val iteratorTry: Try[InputStream] = result match { - case FailureFetchResult(_, e) => - Failure(e) - case SuccessFetchResult(blockId, _, buf) => - // There is a chance that createInputStream can fail (e.g. fetching a local file that does - // not exist, SPARK-4085). In that case, we should propagate the right exception so - // the scheduler gets a FetchFailedException. - Try(buf.createInputStream()).map { inputStream => - new BufferReleasingInputStream(inputStream, this) + result match { + case FailureFetchResult(blockId, address, e) => + throwFetchFailedException(blockId, address, e) + + case SuccessFetchResult(blockId, address, _, buf) => + try { + (result.blockId, new BufferReleasingInputStream(buf.createInputStream(), this)) + } catch { + case NonFatal(t) => + throwFetchFailedException(blockId, address, t) } } + } - (result.blockId, iteratorTry) + private def throwFetchFailedException(blockId: BlockId, address: BlockManagerId, e: Throwable) = { + blockId match { + case ShuffleBlockId(shufId, mapId, reduceId) => + throw new FetchFailedException(address, shufId.toInt, mapId.toInt, reduceId, e) + case _ => + throw new SparkException( + "Failed to get block " + blockId + ", which is not a shuffle block", e) + } } } @@ -366,16 +378,22 @@ object ShuffleBlockFetcherIterator { */ private[storage] sealed trait FetchResult { val blockId: BlockId + val address: BlockManagerId } /** * Result of a fetch from a remote block successfully. * @param blockId block id + * @param address BlockManager that the block was fetched from. * @param size estimated size of the block, used to calculate bytesInFlight. * Note that this is NOT the exact bytes. * @param buf [[ManagedBuffer]] for the content. */ - private[storage] case class SuccessFetchResult(blockId: BlockId, size: Long, buf: ManagedBuffer) + private[storage] case class SuccessFetchResult( + blockId: BlockId, + address: BlockManagerId, + size: Long, + buf: ManagedBuffer) extends FetchResult { require(buf != null) require(size >= 0) @@ -384,8 +402,12 @@ object ShuffleBlockFetcherIterator { /** * Result of a fetch from a remote block unsuccessfully. 
* @param blockId block id + * @param address BlockManager that the block was attempted to be fetched from * @param e the failure exception */ - private[storage] case class FailureFetchResult(blockId: BlockId, e: Throwable) + private[storage] case class FailureFetchResult( + blockId: BlockId, + address: BlockManagerId, + e: Throwable) extends FetchResult } diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index 7a1961137cce5..af4e68950f75a 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -17,13 +17,15 @@ package org.apache.spark +import scala.collection.mutable.ArrayBuffer + import org.mockito.Mockito._ import org.mockito.Matchers.{any, isA} import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef, RpcCallContext, RpcEnv} import org.apache.spark.scheduler.{CompressedMapStatus, MapStatus} import org.apache.spark.shuffle.FetchFailedException -import org.apache.spark.storage.BlockManagerId +import org.apache.spark.storage.{BlockManagerId, ShuffleBlockId} class MapOutputTrackerSuite extends SparkFunSuite { private val conf = new SparkConf @@ -55,9 +57,11 @@ class MapOutputTrackerSuite extends SparkFunSuite { Array(1000L, 10000L))) tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000), Array(10000L, 1000L))) - val statuses = tracker.getServerStatuses(10, 0) - assert(statuses.toSeq === Seq((BlockManagerId("a", "hostA", 1000), size1000), - (BlockManagerId("b", "hostB", 1000), size10000))) + val statuses = tracker.getMapSizesByExecutorId(10, 0) + assert(statuses.toSet === + Seq((BlockManagerId("a", "hostA", 1000), ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000))), + (BlockManagerId("b", "hostB", 1000), ArrayBuffer((ShuffleBlockId(10, 1, 0), size10000)))) + .toSet) tracker.stop() rpcEnv.shutdown() } @@ -75,10 +79,10 @@ class MapOutputTrackerSuite extends SparkFunSuite { tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000), Array(compressedSize10000, compressedSize1000))) assert(tracker.containsShuffle(10)) - assert(tracker.getServerStatuses(10, 0).nonEmpty) + assert(tracker.getMapSizesByExecutorId(10, 0).nonEmpty) tracker.unregisterShuffle(10) assert(!tracker.containsShuffle(10)) - assert(tracker.getServerStatuses(10, 0).isEmpty) + assert(tracker.getMapSizesByExecutorId(10, 0).isEmpty) tracker.stop() rpcEnv.shutdown() @@ -104,7 +108,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { // The remaining reduce task might try to grab the output despite the shuffle failure; // this should cause it to fail, and the scheduler will ignore the failure due to the // stage already being aborted. 
- intercept[FetchFailedException] { tracker.getServerStatuses(10, 1) } + intercept[FetchFailedException] { tracker.getMapSizesByExecutorId(10, 1) } tracker.stop() rpcEnv.shutdown() @@ -126,23 +130,23 @@ class MapOutputTrackerSuite extends SparkFunSuite { masterTracker.registerShuffle(10, 1) masterTracker.incrementEpoch() slaveTracker.updateEpoch(masterTracker.getEpoch) - intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) } + intercept[FetchFailedException] { slaveTracker.getMapSizesByExecutorId(10, 0) } val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) masterTracker.registerMapOutput(10, 0, MapStatus( BlockManagerId("a", "hostA", 1000), Array(1000L))) masterTracker.incrementEpoch() slaveTracker.updateEpoch(masterTracker.getEpoch) - assert(slaveTracker.getServerStatuses(10, 0).toSeq === - Seq((BlockManagerId("a", "hostA", 1000), size1000))) + assert(slaveTracker.getMapSizesByExecutorId(10, 0) === + Seq((BlockManagerId("a", "hostA", 1000), ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000))))) masterTracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000)) masterTracker.incrementEpoch() slaveTracker.updateEpoch(masterTracker.getEpoch) - intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) } + intercept[FetchFailedException] { slaveTracker.getMapSizesByExecutorId(10, 0) } // failure should be cached - intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) } + intercept[FetchFailedException] { slaveTracker.getMapSizesByExecutorId(10, 0) } masterTracker.stop() slaveTracker.stop() diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index c3c2b1ffc1efa..b68102bfb949f 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -66,8 +66,8 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC // All blocks must have non-zero size (0 until NUM_BLOCKS).foreach { id => - val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, id) - assert(statuses.forall(s => s._2 > 0)) + val statuses = SparkEnv.get.mapOutputTracker.getMapSizesByExecutorId(shuffleId, id) + assert(statuses.forall(_._2.forall(blockIdSizePair => blockIdSizePair._2 > 0))) } } @@ -105,8 +105,8 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC assert(c.count === 4) val blockSizes = (0 until NUM_BLOCKS).flatMap { id => - val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, id) - statuses.map(x => x._2) + val statuses = SparkEnv.get.mapOutputTracker.getMapSizesByExecutorId(shuffleId, id) + statuses.flatMap(_._2.map(_._2)) } val nonEmptyBlocks = blockSizes.filter(x => x > 0) @@ -130,8 +130,8 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC assert(c.count === 4) val blockSizes = (0 until NUM_BLOCKS).flatMap { id => - val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, id) - statuses.map(x => x._2) + val statuses = SparkEnv.get.mapOutputTracker.getMapSizesByExecutorId(shuffleId, id) + statuses.flatMap(_._2.map(_._2)) } val nonEmptyBlocks = blockSizes.filter(x => x > 0) diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 86728cb2b62af..3462a82c9cdd3 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ 
b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -483,8 +483,8 @@ class DAGSchedulerSuite complete(taskSets(0), Seq( (Success, makeMapStatus("hostA", 1)), (Success, makeMapStatus("hostB", 1)))) - assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) === - Array(makeBlockManagerId("hostA"), makeBlockManagerId("hostB"))) + assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1).toSet === + HashSet(makeBlockManagerId("hostA"), makeBlockManagerId("hostB"))) complete(taskSets(1), Seq((Success, 42))) assert(results === Map(0 -> 42)) assertDataStructuresEmpty() @@ -510,8 +510,8 @@ class DAGSchedulerSuite // have the 2nd attempt pass complete(taskSets(2), Seq((Success, makeMapStatus("hostA", reduceRdd.partitions.size)))) // we can see both result blocks now - assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1.host) === - Array("hostA", "hostB")) + assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1.host).toSet === + HashSet("hostA", "hostB")) complete(taskSets(3), Seq((Success, 43))) assert(results === Map(0 -> 42, 1 -> 43)) assertDataStructuresEmpty() @@ -527,8 +527,8 @@ class DAGSchedulerSuite (Success, makeMapStatus("hostA", reduceRdd.partitions.size)), (Success, makeMapStatus("hostB", reduceRdd.partitions.size)))) // The MapOutputTracker should know about both map output locations. - assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1.host) === - Array("hostA", "hostB")) + assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1.host).toSet === + HashSet("hostA", "hostB")) // The first result task fails, with a fetch failure for the output from the first mapper. runEvent(CompletionEvent( @@ -577,10 +577,10 @@ class DAGSchedulerSuite (Success, makeMapStatus("hostA", 2)), (Success, makeMapStatus("hostB", 2)))) // The MapOutputTracker should know about both map output locations. - assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1.host) === - Array("hostA", "hostB")) - assert(mapOutputTracker.getServerStatuses(shuffleId, 1).map(_._1.host) === - Array("hostA", "hostB")) + assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1.host).toSet === + HashSet("hostA", "hostB")) + assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 1).map(_._1.host).toSet === + HashSet("hostA", "hostB")) // The first result task fails, with a fetch failure for the output from the first mapper. 
runEvent(CompletionEvent( @@ -713,8 +713,8 @@ class DAGSchedulerSuite taskSet.tasks(1).epoch = newEpoch runEvent(CompletionEvent(taskSet.tasks(1), Success, makeMapStatus("hostA", reduceRdd.partitions.size), null, createFakeTaskInfo(), null)) - assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) === - Array(makeBlockManagerId("hostB"), makeBlockManagerId("hostA"))) + assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1).toSet === + HashSet(makeBlockManagerId("hostB"), makeBlockManagerId("hostA"))) complete(taskSets(1), Seq((Success, 42), (Success, 43))) assert(results === Map(0 -> 42, 1 -> 43)) assertDataStructuresEmpty() @@ -809,8 +809,8 @@ class DAGSchedulerSuite (Success, makeMapStatus("hostB", 1)))) // have hostC complete the resubmitted task complete(taskSets(1), Seq((Success, makeMapStatus("hostC", 1)))) - assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) === - Array(makeBlockManagerId("hostC"), makeBlockManagerId("hostB"))) + assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1).toSet === + HashSet(makeBlockManagerId("hostC"), makeBlockManagerId("hostB"))) complete(taskSets(2), Seq((Success, 42))) assert(results === Map(0 -> 42)) assertDataStructuresEmpty() @@ -981,8 +981,8 @@ class DAGSchedulerSuite submit(reduceRdd, Array(0)) complete(taskSets(0), Seq( (Success, makeMapStatus("hostA", 1)))) - assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) === - Array(makeBlockManagerId("hostA"))) + assert(mapOutputTracker.getMapSizesByExecutorId(shuffleId, 0).map(_._1).toSet === + HashSet(makeBlockManagerId("hostA"))) // Reducer should run on the same host that map task ran val reduceTaskSet = taskSets(1) diff --git a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala index 28ca68698e3dc..6c9cb448e7833 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala @@ -115,11 +115,15 @@ class HashShuffleReaderSuite extends SparkFunSuite with LocalSparkContext { // Make a mocked MapOutputTracker for the shuffle reader to use to determine what // shuffle data to read. val mapOutputTracker = mock(classOf[MapOutputTracker]) - // Test a scenario where all data is local, just to avoid creating a bunch of additional mocks - // for the code to read data over the network. - val statuses: Array[(BlockManagerId, Long)] = - Array.fill(numMaps)((localBlockManagerId, byteOutputStream.size().toLong)) - when(mapOutputTracker.getServerStatuses(shuffleId, reduceId)).thenReturn(statuses) + when(mapOutputTracker.getMapSizesByExecutorId(shuffleId, reduceId)).thenReturn { + // Test a scenario where all data is local, to avoid creating a bunch of additional mocks + // for the code to read data over the network. + val shuffleBlockIdsAndSizes = (0 until numMaps).map { mapId => + val shuffleBlockId = ShuffleBlockId(shuffleId, mapId, reduceId) + (shuffleBlockId, byteOutputStream.size().toLong) + } + Seq((localBlockManagerId, shuffleBlockIdsAndSizes)) + } // Create a mocked shuffle handle to pass into HashShuffleReader. 
val shuffleHandle = { diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 9ced4148d7206..64f3fbdcebed9 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -33,6 +33,7 @@ import org.apache.spark.{SparkFunSuite, TaskContextImpl} import org.apache.spark.network._ import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.network.shuffle.BlockFetchingListener +import org.apache.spark.shuffle.FetchFailedException class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodTester { @@ -106,13 +107,11 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT for (i <- 0 until 5) { assert(iterator.hasNext, s"iterator should have 5 elements but actually has $i elements") val (blockId, inputStream) = iterator.next() - assert(inputStream.isSuccess, - s"iterator should have 5 elements defined but actually has $i elements") // Make sure we release buffers when a wrapped input stream is closed. val mockBuf = localBlocks.getOrElse(blockId, remoteBlocks(blockId)) // Note: ShuffleBlockFetcherIterator wraps input streams in a BufferReleasingInputStream - val wrappedInputStream = inputStream.get.asInstanceOf[BufferReleasingInputStream] + val wrappedInputStream = inputStream.asInstanceOf[BufferReleasingInputStream] verify(mockBuf, times(0)).release() val delegateAccess = PrivateMethod[InputStream]('delegate) @@ -175,11 +174,11 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT 48 * 1024 * 1024) verify(blocks(ShuffleBlockId(0, 0, 0)), times(0)).release() - iterator.next()._2.get.close() // close() first block's input stream + iterator.next()._2.close() // close() first block's input stream verify(blocks(ShuffleBlockId(0, 0, 0)), times(1)).release() // Get the 2nd block but do not exhaust the iterator - val subIter = iterator.next()._2.get + val subIter = iterator.next()._2 // Complete the task; then the 2nd block buffer should be exhausted verify(blocks(ShuffleBlockId(0, 1, 0)), times(0)).release() @@ -239,9 +238,10 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT // Continue only after the mock calls onBlockFetchFailure sem.acquire() - // The first block should be defined, and the last two are not defined (due to failure) - assert(iterator.next()._2.isSuccess) - assert(iterator.next()._2.isFailure) - assert(iterator.next()._2.isFailure) + // The first block should be returned without an exception, and the last two should throw + // FetchFailedExceptions (due to failure) + iterator.next() + intercept[FetchFailedException] { iterator.next() } + intercept[FetchFailedException] { iterator.next() } } } diff --git a/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala index 6c40685484ed4..61601016e005e 100644 --- a/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.util +import scala.collection.mutable.ArrayBuffer + import java.util.concurrent.TimeoutException import akka.actor.ActorNotFound @@ -24,7 +26,7 @@ import akka.actor.ActorNotFound import org.apache.spark._ import org.apache.spark.rpc.RpcEnv import 
org.apache.spark.scheduler.MapStatus -import org.apache.spark.storage.BlockManagerId +import org.apache.spark.storage.{BlockManagerId, ShuffleBlockId} import org.apache.spark.SSLSampleConfigs._ @@ -107,8 +109,9 @@ class AkkaUtilsSuite extends SparkFunSuite with LocalSparkContext with ResetSyst slaveTracker.updateEpoch(masterTracker.getEpoch) // this should succeed since security off - assert(slaveTracker.getServerStatuses(10, 0).toSeq === - Seq((BlockManagerId("a", "hostA", 1000), size1000))) + assert(slaveTracker.getMapSizesByExecutorId(10, 0).toSeq === + Seq((BlockManagerId("a", "hostA", 1000), + ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000))))) rpcEnv.shutdown() slaveRpcEnv.shutdown() @@ -153,8 +156,9 @@ class AkkaUtilsSuite extends SparkFunSuite with LocalSparkContext with ResetSyst slaveTracker.updateEpoch(masterTracker.getEpoch) // this should succeed since security on and passwords match - assert(slaveTracker.getServerStatuses(10, 0).toSeq === - Seq((BlockManagerId("a", "hostA", 1000), size1000))) + assert(slaveTracker.getMapSizesByExecutorId(10, 0) === + Seq((BlockManagerId("a", "hostA", 1000), + ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000))))) rpcEnv.shutdown() slaveRpcEnv.shutdown() @@ -232,8 +236,8 @@ class AkkaUtilsSuite extends SparkFunSuite with LocalSparkContext with ResetSyst slaveTracker.updateEpoch(masterTracker.getEpoch) // this should succeed since security off - assert(slaveTracker.getServerStatuses(10, 0).toSeq === - Seq((BlockManagerId("a", "hostA", 1000), size1000))) + assert(slaveTracker.getMapSizesByExecutorId(10, 0) === + Seq((BlockManagerId("a", "hostA", 1000), ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000))))) rpcEnv.shutdown() slaveRpcEnv.shutdown() @@ -278,8 +282,8 @@ class AkkaUtilsSuite extends SparkFunSuite with LocalSparkContext with ResetSyst masterTracker.incrementEpoch() slaveTracker.updateEpoch(masterTracker.getEpoch) - assert(slaveTracker.getServerStatuses(10, 0).toSeq === - Seq((BlockManagerId("a", "hostA", 1000), size1000))) + assert(slaveTracker.getMapSizesByExecutorId(10, 0) === + Seq((BlockManagerId("a", "hostA", 1000), ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000))))) rpcEnv.shutdown() slaveRpcEnv.shutdown() From f5b6dc5e3e7e3b586096b71164f052318b840e8a Mon Sep 17 00:00:00 2001 From: Michael Allman Date: Tue, 21 Jul 2015 11:14:31 +0100 Subject: [PATCH 21/32] [SPARK-8401] [BUILD] Scala version switching build enhancements These commits address a few minor issues in the Scala cross-version support in the build: 1. Correct two missing `${scala.binary.version}` pom file substitutions. 2. Don't update `scala.binary.version` in parent POM. This property is set through profiles. 3. Update the source of the generated scaladocs in `docs/_plugins/copy_api_dirs.rb`. 4. Factor common code out of `dev/change-version-to-*.sh` and add some validation. We also test `sed` to see if it's GNU sed and try `gsed` as an alternative if not. This prevents the script from running with a non-GNU sed. This is my original work and I license this work to the Spark project under the Apache License. 
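As a usage sketch (illustrative only, not part of the patch): the consolidated script takes the target Scala binary version as its single argument and validates it against the supported versions before rewriting the POMs.

  # switch the build between the supported Scala binary versions
  ./dev/change-scala-version.sh 2.11
  ./dev/change-scala-version.sh 2.10
  # any other argument fails check_scala_version and exits with an error
  ./dev/change-scala-version.sh 2.12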
Author: Michael Allman Closes #6832 from mallman/scala-versions and squashes the following commits: cde2f17 [Michael Allman] Delete dev/change-version-to-*.sh, replacing them with single dev/change-scala-version.sh script that takes a version as argument 02296f2 [Michael Allman] Make the scala version change scripts cross-platform by restricting ourselves to POSIX sed syntax instead of looking for GNU sed ad9b40a [Michael Allman] Factor change-scala-version.sh out of change-version-to-*.sh, adding command line argument validation and testing for GNU sed bdd20bf [Michael Allman] Update source of scaladocs when changing Scala version 475088e [Michael Allman] Replace jackson-module-scala_2.10 with jackson-module-scala_${scala.binary.version} --- core/pom.xml | 2 +- dev/change-scala-version.sh | 66 ++++++++++++++++++++++++++++ dev/change-version-to-2.10.sh | 26 ----------- dev/change-version-to-2.11.sh | 26 ----------- dev/create-release/create-release.sh | 6 +-- docs/building-spark.md | 2 +- pom.xml | 2 +- 7 files changed, 72 insertions(+), 58 deletions(-) create mode 100755 dev/change-scala-version.sh delete mode 100755 dev/change-version-to-2.10.sh delete mode 100755 dev/change-version-to-2.11.sh diff --git a/core/pom.xml b/core/pom.xml index 73f7a75cab9d3..95f36eb348698 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -261,7 +261,7 @@ com.fasterxml.jackson.module - jackson-module-scala_2.10 + jackson-module-scala_${scala.binary.version} org.apache.derby diff --git a/dev/change-scala-version.sh b/dev/change-scala-version.sh new file mode 100755 index 0000000000000..b81c00c9d6d9d --- /dev/null +++ b/dev/change-scala-version.sh @@ -0,0 +1,66 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +set -e + +usage() { + echo "Usage: $(basename $0) " 1>&2 + exit 1 +} + +if [ $# -ne 1 ]; then + usage +fi + +TO_VERSION=$1 + +VALID_VERSIONS=( 2.10 2.11 ) + +check_scala_version() { + for i in ${VALID_VERSIONS[*]}; do [ $i = "$1" ] && return 0; done + echo "Invalid Scala version: $1. Valid versions: ${VALID_VERSIONS[*]}" 1>&2 + exit 1 +} + +check_scala_version "$TO_VERSION" + +if [ $TO_VERSION = "2.11" ]; then + FROM_VERSION="2.10" +else + FROM_VERSION="2.11" +fi + +sed_i() { + sed -e "$1" "$2" > "$2.tmp" && mv "$2.tmp" "$2" +} + +export -f sed_i + +BASEDIR=$(dirname $0)/.. 
+find "$BASEDIR" -name 'pom.xml' -not -path '*target*' -print \ + -exec bash -c "sed_i 's/\(artifactId.*\)_'$FROM_VERSION'/\1_'$TO_VERSION'/g' {}" \; + +# Also update in parent POM +# Match any scala binary version to ensure idempotency +sed_i '1,/[0-9]*\.[0-9]*[0-9]*\.[0-9]*'$TO_VERSION' in parent POM -sed -i -e '0,/2.112.10 in parent POM -sed -i -e '0,/2.102.11 com.fasterxml.jackson.module - jackson-module-scala_2.10 + jackson-module-scala_${scala.binary.version} ${fasterxml.jackson.version} From be5c5d3741256697cc76938a8ed6f609eb2d4b11 Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Tue, 21 Jul 2015 08:25:50 -0700 Subject: [PATCH 22/32] [SPARK-9081] [SPARK-9168] [SQL] nanvl & dropna/fillna supporting nan as well JIRA: https://issues.apache.org/jira/browse/SPARK-9081 https://issues.apache.org/jira/browse/SPARK-9168 This PR target at two modifications: 1. Change `isNaN` to return `false` on `null` input 2. Make `dropna` and `fillna` to fill/drop NaN values as well 3. Implement `nanvl` Author: Yijie Shen Closes #7523 from yjshen/fillna_dropna and squashes the following commits: f0a51db [Yijie Shen] make coalesce untouched and implement nanvl 1d3e35f [Yijie Shen] make Coalesce aware of NaN in order to support fillna 2760cbc [Yijie Shen] change isNaN(null) to false as well as implement dropna --- .../catalyst/analysis/FunctionRegistry.scala | 1 + .../catalyst/expressions/nullFunctions.scala | 104 ++++++++++++++---- .../sql/catalyst/expressions/predicates.scala | 5 +- .../expressions/NullFunctionsSuite.scala | 39 ++++++- .../scala/org/apache/spark/sql/Column.scala | 2 +- .../spark/sql/DataFrameNaFunctions.scala | 52 +++++---- .../org/apache/spark/sql/functions.scala | 13 ++- .../spark/sql/ColumnExpressionSuite.scala | 25 ++++- .../spark/sql/DataFrameNaFunctionsSuite.scala | 69 ++++++------ 9 files changed, 222 insertions(+), 88 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 13523720daff0..e3d8d2adf2135 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -89,6 +89,7 @@ object FunctionRegistry { expression[CreateStruct]("struct"), expression[CreateNamedStruct]("named_struct"), expression[Sqrt]("sqrt"), + expression[NaNvl]("nanvl"), // math functions expression[Acos]("acos"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala index 98c67084642e3..287718fab7f0d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullFunctions.scala @@ -83,7 +83,7 @@ case class Coalesce(children: Seq[Expression]) extends Expression { /** - * Evaluates to `true` if it's NaN or null + * Evaluates to `true` iff it's NaN. 
*/ case class IsNaN(child: Expression) extends UnaryExpression with Predicate with ImplicitCastInputTypes { @@ -95,7 +95,7 @@ case class IsNaN(child: Expression) extends UnaryExpression override def eval(input: InternalRow): Any = { val value = child.eval(input) if (value == null) { - true + false } else { child.dataType match { case DoubleType => value.asInstanceOf[Double].isNaN @@ -107,26 +107,65 @@ case class IsNaN(child: Expression) extends UnaryExpression override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val eval = child.gen(ctx) child.dataType match { - case FloatType => + case DoubleType | FloatType => s""" ${eval.code} boolean ${ev.isNull} = false; ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; - if (${eval.isNull}) { - ${ev.primitive} = true; - } else { - ${ev.primitive} = Float.isNaN(${eval.primitive}); - } + ${ev.primitive} = !${eval.isNull} && Double.isNaN(${eval.primitive}); """ - case DoubleType => + } + } +} + +/** + * An Expression evaluates to `left` iff it's not NaN, or evaluates to `right` otherwise. + * This Expression is useful for mapping NaN values to null. + */ +case class NaNvl(left: Expression, right: Expression) + extends BinaryExpression with ImplicitCastInputTypes { + + override def dataType: DataType = left.dataType + + override def inputTypes: Seq[AbstractDataType] = + Seq(TypeCollection(DoubleType, FloatType), TypeCollection(DoubleType, FloatType)) + + override def eval(input: InternalRow): Any = { + val value = left.eval(input) + if (value == null) { + null + } else { + left.dataType match { + case DoubleType => + if (!value.asInstanceOf[Double].isNaN) value else right.eval(input) + case FloatType => + if (!value.asInstanceOf[Float].isNaN) value else right.eval(input) + } + } + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val leftGen = left.gen(ctx) + val rightGen = right.gen(ctx) + left.dataType match { + case DoubleType | FloatType => s""" - ${eval.code} + ${leftGen.code} boolean ${ev.isNull} = false; ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; - if (${eval.isNull}) { - ${ev.primitive} = true; + if (${leftGen.isNull}) { + ${ev.isNull} = true; } else { - ${ev.primitive} = Double.isNaN(${eval.primitive}); + if (!Double.isNaN(${leftGen.primitive})) { + ${ev.primitive} = ${leftGen.primitive}; + } else { + ${rightGen.code} + if (${rightGen.isNull}) { + ${ev.isNull} = true; + } else { + ${ev.primitive} = ${rightGen.primitive}; + } + } } """ } @@ -186,8 +225,15 @@ case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate var numNonNulls = 0 var i = 0 while (i < childrenArray.length && numNonNulls < n) { - if (childrenArray(i).eval(input) != null) { - numNonNulls += 1 + val evalC = childrenArray(i).eval(input) + if (evalC != null) { + childrenArray(i).dataType match { + case DoubleType => + if (!evalC.asInstanceOf[Double].isNaN) numNonNulls += 1 + case FloatType => + if (!evalC.asInstanceOf[Float].isNaN) numNonNulls += 1 + case _ => numNonNulls += 1 + } } i += 1 } @@ -198,14 +244,26 @@ case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate val nonnull = ctx.freshName("nonnull") val code = children.map { e => val eval = e.gen(ctx) - s""" - if ($nonnull < $n) { - ${eval.code} - if (!${eval.isNull}) { - $nonnull += 1; - } - } - """ + e.dataType match { + case DoubleType | FloatType => + s""" + if ($nonnull < $n) { + ${eval.code} + if (!${eval.isNull} && 
!Double.isNaN(${eval.primitive})) { + $nonnull += 1; + } + } + """ + case _ => + s""" + if ($nonnull < $n) { + ${eval.code} + if (!${eval.isNull}) { + $nonnull += 1; + } + } + """ + } }.mkString("\n") s""" int $nonnull = 0; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index a53ec31ee6a4b..3f1bd2a925fe7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -17,10 +17,10 @@ package org.apache.spark.sql.catalyst.expressions +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenFallback, GeneratedExpressionCode, CodeGenContext} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.util.TypeUtils -import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -121,7 +121,6 @@ case class InSet(child: Expression, hset: Set[Any]) } } - case class And(left: Expression, right: Expression) extends BinaryOperator with Predicate { override def inputType: AbstractDataType = BooleanType diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala index 765cc7a969b5d..0728f6695c39d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullFunctionsSuite.scala @@ -49,12 +49,22 @@ class NullFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(IsNaN(Literal(Double.NaN)), true) checkEvaluation(IsNaN(Literal(Float.NaN)), true) checkEvaluation(IsNaN(Literal(math.log(-3))), true) - checkEvaluation(IsNaN(Literal.create(null, DoubleType)), true) + checkEvaluation(IsNaN(Literal.create(null, DoubleType)), false) checkEvaluation(IsNaN(Literal(Double.PositiveInfinity)), false) checkEvaluation(IsNaN(Literal(Float.MaxValue)), false) checkEvaluation(IsNaN(Literal(5.5f)), false) } + test("nanvl") { + checkEvaluation(NaNvl(Literal(5.0), Literal.create(null, DoubleType)), 5.0) + checkEvaluation(NaNvl(Literal.create(null, DoubleType), Literal(5.0)), null) + checkEvaluation(NaNvl(Literal.create(null, DoubleType), Literal(Double.NaN)), null) + checkEvaluation(NaNvl(Literal(Double.NaN), Literal(5.0)), 5.0) + checkEvaluation(NaNvl(Literal(Double.NaN), Literal.create(null, DoubleType)), null) + assert(NaNvl(Literal(Double.NaN), Literal(Double.NaN)). 
+ eval(EmptyRow).asInstanceOf[Double].isNaN) + } + test("coalesce") { testAllTypes { (value: Any, tpe: DataType) => val lit = Literal.create(value, tpe) @@ -66,4 +76,31 @@ class NullFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Coalesce(Seq(nullLit, nullLit, lit)), value) } } + + test("AtLeastNNonNulls") { + val mix = Seq(Literal("x"), + Literal.create(null, StringType), + Literal.create(null, DoubleType), + Literal(Double.NaN), + Literal(5f)) + + val nanOnly = Seq(Literal("x"), + Literal(10.0), + Literal(Float.NaN), + Literal(math.log(-2)), + Literal(Double.MaxValue)) + + val nullOnly = Seq(Literal("x"), + Literal.create(null, DoubleType), + Literal.create(null, DecimalType.Unlimited), + Literal(Float.MaxValue), + Literal(false)) + + checkEvaluation(AtLeastNNonNulls(2, mix), true, EmptyRow) + checkEvaluation(AtLeastNNonNulls(3, mix), false, EmptyRow) + checkEvaluation(AtLeastNNonNulls(3, nanOnly), true, EmptyRow) + checkEvaluation(AtLeastNNonNulls(4, nanOnly), false, EmptyRow) + checkEvaluation(AtLeastNNonNulls(3, nullOnly), true, EmptyRow) + checkEvaluation(AtLeastNNonNulls(4, nullOnly), false, EmptyRow) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 221cd04c6d288..6e2a6525bf17e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -401,7 +401,7 @@ class Column(protected[sql] val expr: Expression) extends Logging { } /** - * True if the current expression is NaN or null + * True if the current expression is NaN. * * @group expr_ops * @since 1.5.0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala index 8681a56c82f1e..a4fd4cf3b330b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameNaFunctions.scala @@ -37,24 +37,24 @@ import org.apache.spark.sql.types._ final class DataFrameNaFunctions private[sql](df: DataFrame) { /** - * Returns a new [[DataFrame]] that drops rows containing any null values. + * Returns a new [[DataFrame]] that drops rows containing any null or NaN values. * * @since 1.3.1 */ def drop(): DataFrame = drop("any", df.columns) /** - * Returns a new [[DataFrame]] that drops rows containing null values. + * Returns a new [[DataFrame]] that drops rows containing null or NaN values. * - * If `how` is "any", then drop rows containing any null values. - * If `how` is "all", then drop rows only if every column is null for that row. + * If `how` is "any", then drop rows containing any null or NaN values. + * If `how` is "all", then drop rows only if every column is null or NaN for that row. * * @since 1.3.1 */ def drop(how: String): DataFrame = drop(how, df.columns) /** - * Returns a new [[DataFrame]] that drops rows containing any null values + * Returns a new [[DataFrame]] that drops rows containing any null or NaN values * in the specified columns. * * @since 1.3.1 @@ -62,7 +62,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { def drop(cols: Array[String]): DataFrame = drop(cols.toSeq) /** - * (Scala-specific) Returns a new [[DataFrame ]] that drops rows containing any null values + * (Scala-specific) Returns a new [[DataFrame]] that drops rows containing any null or NaN values * in the specified columns. 
* * @since 1.3.1 @@ -70,22 +70,22 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { def drop(cols: Seq[String]): DataFrame = drop(cols.size, cols) /** - * Returns a new [[DataFrame]] that drops rows containing null values + * Returns a new [[DataFrame]] that drops rows containing null or NaN values * in the specified columns. * - * If `how` is "any", then drop rows containing any null values in the specified columns. - * If `how` is "all", then drop rows only if every specified column is null for that row. + * If `how` is "any", then drop rows containing any null or NaN values in the specified columns. + * If `how` is "all", then drop rows only if every specified column is null or NaN for that row. * * @since 1.3.1 */ def drop(how: String, cols: Array[String]): DataFrame = drop(how, cols.toSeq) /** - * (Scala-specific) Returns a new [[DataFrame]] that drops rows containing null values + * (Scala-specific) Returns a new [[DataFrame]] that drops rows containing null or NaN values * in the specified columns. * - * If `how` is "any", then drop rows containing any null values in the specified columns. - * If `how` is "all", then drop rows only if every specified column is null for that row. + * If `how` is "any", then drop rows containing any null or NaN values in the specified columns. + * If `how` is "all", then drop rows only if every specified column is null or NaN for that row. * * @since 1.3.1 */ @@ -98,15 +98,16 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { } /** - * Returns a new [[DataFrame]] that drops rows containing less than `minNonNulls` non-null values. + * Returns a new [[DataFrame]] that drops rows containing + * less than `minNonNulls` non-null and non-NaN values. * * @since 1.3.1 */ def drop(minNonNulls: Int): DataFrame = drop(minNonNulls, df.columns) /** - * Returns a new [[DataFrame]] that drops rows containing less than `minNonNulls` non-null - * values in the specified columns. + * Returns a new [[DataFrame]] that drops rows containing + * less than `minNonNulls` non-null and non-NaN values in the specified columns. * * @since 1.3.1 */ @@ -114,32 +115,33 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { /** * (Scala-specific) Returns a new [[DataFrame]] that drops rows containing less than - * `minNonNulls` non-null values in the specified columns. + * `minNonNulls` non-null and non-NaN values in the specified columns. * * @since 1.3.1 */ def drop(minNonNulls: Int, cols: Seq[String]): DataFrame = { - // Filtering condition -- only keep the row if it has at least `minNonNulls` non-null values. + // Filtering condition: + // only keep the row if it has at least `minNonNulls` non-null and non-NaN values. val predicate = AtLeastNNonNulls(minNonNulls, cols.map(name => df.resolve(name))) df.filter(Column(predicate)) } /** - * Returns a new [[DataFrame]] that replaces null values in numeric columns with `value`. + * Returns a new [[DataFrame]] that replaces null or NaN values in numeric columns with `value`. * * @since 1.3.1 */ def fill(value: Double): DataFrame = fill(value, df.columns) /** - * Returns a new [[DataFrame ]] that replaces null values in string columns with `value`. + * Returns a new [[DataFrame]] that replaces null values in string columns with `value`. * * @since 1.3.1 */ def fill(value: String): DataFrame = fill(value, df.columns) /** - * Returns a new [[DataFrame]] that replaces null values in specified numeric columns. 
+ * Returns a new [[DataFrame]] that replaces null or NaN values in specified numeric columns. * If a specified column is not a numeric column, it is ignored. * * @since 1.3.1 @@ -147,7 +149,7 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { def fill(value: Double, cols: Array[String]): DataFrame = fill(value, cols.toSeq) /** - * (Scala-specific) Returns a new [[DataFrame]] that replaces null values in specified + * (Scala-specific) Returns a new [[DataFrame]] that replaces null or NaN values in specified * numeric columns. If a specified column is not a numeric column, it is ignored. * * @since 1.3.1 @@ -391,7 +393,13 @@ final class DataFrameNaFunctions private[sql](df: DataFrame) { * Returns a [[Column]] expression that replaces null value in `col` with `replacement`. */ private def fillCol[T](col: StructField, replacement: T): Column = { - coalesce(df.col("`" + col.name + "`"), lit(replacement).cast(col.dataType)).as(col.name) + col.dataType match { + case DoubleType | FloatType => + coalesce(nanvl(df.col("`" + col.name + "`"), lit(null)), + lit(replacement).cast(col.dataType)).as(col.name) + case _ => + coalesce(df.col("`" + col.name + "`"), lit(replacement).cast(col.dataType)).as(col.name) + } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 60b089180c876..d94d7335828c5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -595,7 +595,7 @@ object functions { } /** - * Returns the first column that is not null. + * Returns the first column that is not null and not NaN. * {{{ * df.select(coalesce(df("a"), df("b"))) * }}} @@ -612,7 +612,7 @@ object functions { def explode(e: Column): Column = Explode(e.expr) /** - * Return true if the column is NaN or null + * Return true iff the column is NaN. * * @group normal_funcs * @since 1.5.0 @@ -636,6 +636,15 @@ object functions { */ def monotonicallyIncreasingId(): Column = execution.expressions.MonotonicallyIncreasingID() + /** + * Return an alternative value `r` if `l` is NaN. + * This function is useful for mapping NaN values to null. + * + * @group normal_funcs + * @since 1.5.0 + */ + def nanvl(l: Column, r: Column): Column = NaNvl(l.expr, r.expr) + /** * Unary minus, i.e. negate the expression. 
* {{{ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 6bd5804196853..1f9f7118c3f04 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -211,15 +211,34 @@ class ColumnExpressionSuite extends QueryTest { checkAnswer( testData.select($"a".isNaN, $"b".isNaN), - Row(true, true) :: Row(true, true) :: Row(true, true) :: Row(false, false) :: Nil) + Row(true, true) :: Row(true, true) :: Row(false, false) :: Row(false, false) :: Nil) checkAnswer( testData.select(isNaN($"a"), isNaN($"b")), - Row(true, true) :: Row(true, true) :: Row(true, true) :: Row(false, false) :: Nil) + Row(true, true) :: Row(true, true) :: Row(false, false) :: Row(false, false) :: Nil) checkAnswer( ctx.sql("select isnan(15), isnan('invalid')"), - Row(false, true)) + Row(false, false)) + } + + test("nanvl") { + val testData = ctx.createDataFrame(ctx.sparkContext.parallelize( + Row(null, 3.0, Double.NaN, Double.PositiveInfinity) :: Nil), + StructType(Seq(StructField("a", DoubleType), StructField("b", DoubleType), + StructField("c", DoubleType), StructField("d", DoubleType)))) + + checkAnswer( + testData.select( + nanvl($"a", lit(5)), nanvl($"b", lit(10)), + nanvl($"c", lit(null).cast(DoubleType)), nanvl($"d", lit(10))), + Row(null, 3.0, null, Double.PositiveInfinity) + ) + testData.registerTempTable("t") + checkAnswer( + ctx.sql("select nanvl(a, 5), nanvl(b, 10), nanvl(c, null), nanvl(d, 10) from t"), + Row(null, 3.0, null, Double.PositiveInfinity) + ) } test("===") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala index 495701d4f616c..dbe3b44ee2c79 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala @@ -30,8 +30,10 @@ class DataFrameNaFunctionsSuite extends QueryTest { ("Bob", 16, 176.5), ("Alice", null, 164.3), ("David", 60, null), + ("Nina", 25, Double.NaN), ("Amy", null, null), - (null, null, null)).toDF("name", "age", "height") + (null, null, null) + ).toDF("name", "age", "height") } test("drop") { @@ -39,12 +41,12 @@ class DataFrameNaFunctionsSuite extends QueryTest { val rows = input.collect() checkAnswer( - input.na.drop("name" :: Nil), - rows(0) :: rows(1) :: rows(2) :: rows(3) :: Nil) + input.na.drop("name" :: Nil).select("name"), + Row("Bob") :: Row("Alice") :: Row("David") :: Row("Nina") :: Row("Amy") :: Nil) checkAnswer( - input.na.drop("age" :: Nil), - rows(0) :: rows(2) :: Nil) + input.na.drop("age" :: Nil).select("name"), + Row("Bob") :: Row("David") :: Row("Nina") :: Nil) checkAnswer( input.na.drop("age" :: "height" :: Nil), @@ -67,8 +69,8 @@ class DataFrameNaFunctionsSuite extends QueryTest { val rows = input.collect() checkAnswer( - input.na.drop("all"), - rows(0) :: rows(1) :: rows(2) :: rows(3) :: Nil) + input.na.drop("all").select("name"), + Row("Bob") :: Row("Alice") :: Row("David") :: Row("Nina") :: Row("Amy") :: Nil) checkAnswer( input.na.drop("any"), @@ -79,8 +81,8 @@ class DataFrameNaFunctionsSuite extends QueryTest { rows(0) :: Nil) checkAnswer( - input.na.drop("all", Seq("age", "height")), - rows(0) :: rows(1) :: rows(2) :: Nil) + input.na.drop("all", Seq("age", "height")).select("name"), + Row("Bob") :: Row("Alice") :: Row("David") 
:: Row("Nina") :: Nil) } test("drop with threshold") { @@ -108,6 +110,7 @@ class DataFrameNaFunctionsSuite extends QueryTest { Row("Bob", 16, 176.5) :: Row("Alice", 50, 164.3) :: Row("David", 60, 50.6) :: + Row("Nina", 25, 50.6) :: Row("Amy", 50, 50.6) :: Row(null, 50, 50.6) :: Nil) @@ -117,17 +120,19 @@ class DataFrameNaFunctionsSuite extends QueryTest { // string checkAnswer( input.na.fill("unknown").select("name"), - Row("Bob") :: Row("Alice") :: Row("David") :: Row("Amy") :: Row("unknown") :: Nil) + Row("Bob") :: Row("Alice") :: Row("David") :: + Row("Nina") :: Row("Amy") :: Row("unknown") :: Nil) assert(input.na.fill("unknown").columns.toSeq === input.columns.toSeq) // fill double with subset columns checkAnswer( - input.na.fill(50.6, "age" :: Nil), - Row("Bob", 16, 176.5) :: - Row("Alice", 50, 164.3) :: - Row("David", 60, null) :: - Row("Amy", 50, null) :: - Row(null, 50, null) :: Nil) + input.na.fill(50.6, "age" :: Nil).select("name", "age"), + Row("Bob", 16) :: + Row("Alice", 50) :: + Row("David", 60) :: + Row("Nina", 25) :: + Row("Amy", 50) :: + Row(null, 50) :: Nil) // fill string with subset columns checkAnswer( @@ -164,29 +169,27 @@ class DataFrameNaFunctionsSuite extends QueryTest { 16 -> 61, 60 -> 6, 164.3 -> 461.3 // Alice is really tall - )) + )).collect() - checkAnswer( - out, - Row("Bob", 61, 176.5) :: - Row("Alice", null, 461.3) :: - Row("David", 6, null) :: - Row("Amy", null, null) :: - Row(null, null, null) :: Nil) + assert(out(0) === Row("Bob", 61, 176.5)) + assert(out(1) === Row("Alice", null, 461.3)) + assert(out(2) === Row("David", 6, null)) + assert(out(3).get(2).asInstanceOf[Double].isNaN) + assert(out(4) === Row("Amy", null, null)) + assert(out(5) === Row(null, null, null)) // Replace only the age column val out1 = input.na.replace("age", Map( 16 -> 61, 60 -> 6, 164.3 -> 461.3 // Alice is really tall - )) - - checkAnswer( - out1, - Row("Bob", 61, 176.5) :: - Row("Alice", null, 164.3) :: - Row("David", 6, null) :: - Row("Amy", null, null) :: - Row(null, null, null) :: Nil) + )).collect() + + assert(out1(0) === Row("Bob", 61, 176.5)) + assert(out1(1) === Row("Alice", null, 164.3)) + assert(out1(2) === Row("David", 6, null)) + assert(out1(3).get(2).asInstanceOf[Double].isNaN) + assert(out1(4) === Row("Amy", null, null)) + assert(out1(5) === Row(null, null, null)) } } From df4ddb3120be28df381c11a36312620e58034b93 Mon Sep 17 00:00:00 2001 From: petz2000 Date: Tue, 21 Jul 2015 08:50:43 -0700 Subject: [PATCH 23/32] [SPARK-8915] [DOCUMENTATION, MLLIB] Added @since tags to mllib.classification Created since tags for methods in mllib.classification Author: petz2000 Closes #7371 from petz2000/add_since_mllib.classification and squashes the following commits: 39fe291 [petz2000] Removed whitespace in block comment c9b1e03 [petz2000] Removed @since tags again from protected and private methods cd759b6 [petz2000] Added @since tags to methods --- .../classification/ClassificationModel.scala | 3 +++ .../classification/LogisticRegression.scala | 17 +++++++++++++++++ .../spark/mllib/classification/NaiveBayes.scala | 3 +++ .../apache/spark/mllib/classification/SVM.scala | 16 ++++++++++++++++ 4 files changed, 39 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala index 35a0db76f3a8c..ba73024e3c04d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala +++ 
b/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala @@ -36,6 +36,7 @@ trait ClassificationModel extends Serializable { * * @param testData RDD representing data points to be predicted * @return an RDD[Double] where each entry contains the corresponding prediction + * @since 0.8.0 */ def predict(testData: RDD[Vector]): RDD[Double] @@ -44,6 +45,7 @@ trait ClassificationModel extends Serializable { * * @param testData array representing a single data point * @return predicted category from the trained model + * @since 0.8.0 */ def predict(testData: Vector): Double @@ -51,6 +53,7 @@ trait ClassificationModel extends Serializable { * Predict values for examples stored in a JavaRDD. * @param testData JavaRDD representing data points to be predicted * @return a JavaRDD[java.lang.Double] where each entry contains the corresponding prediction + * @since 0.8.0 */ def predict(testData: JavaRDD[Vector]): JavaRDD[java.lang.Double] = predict(testData.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Double]] diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala index 2df4d21e8cd55..268642ac6a2f6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala @@ -85,6 +85,7 @@ class LogisticRegressionModel ( * in Binary Logistic Regression. An example with prediction score greater than or equal to * this threshold is identified as an positive, and negative otherwise. The default value is 0.5. * It is only used for binary classification. + * @since 1.0.0 */ @Experimental def setThreshold(threshold: Double): this.type = { @@ -96,6 +97,7 @@ class LogisticRegressionModel ( * :: Experimental :: * Returns the threshold (if any) used for converting raw prediction scores into 0/1 predictions. * It is only used for binary classification. + * @since 1.3.0 */ @Experimental def getThreshold: Option[Double] = threshold @@ -104,6 +106,7 @@ class LogisticRegressionModel ( * :: Experimental :: * Clears the threshold so that `predict` will output raw prediction scores. * It is only used for binary classification. + * @since 1.0.0 */ @Experimental def clearThreshold(): this.type = { @@ -155,6 +158,9 @@ class LogisticRegressionModel ( } } + /** + * @since 1.3.0 + */ override def save(sc: SparkContext, path: String): Unit = { GLMClassificationModel.SaveLoadV1_0.save(sc, path, this.getClass.getName, numFeatures, numClasses, weights, intercept, threshold) @@ -162,6 +168,9 @@ class LogisticRegressionModel ( override protected def formatVersion: String = "1.0" + /** + * @since 1.4.0 + */ override def toString: String = { s"${super.toString}, numClasses = ${numClasses}, threshold = ${threshold.getOrElse("None")}" } @@ -169,6 +178,9 @@ class LogisticRegressionModel ( object LogisticRegressionModel extends Loader[LogisticRegressionModel] { + /** + * @since 1.3.0 + */ override def load(sc: SparkContext, path: String): LogisticRegressionModel = { val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path) // Hard-code class name string in case it changes in the future @@ -249,6 +261,7 @@ object LogisticRegressionWithSGD { * @param miniBatchFraction Fraction of data to be used per iteration. * @param initialWeights Initial set of weights to be used. Array should be equal in size to * the number of features in the data. 
+ * @since 1.0.0 */ def train( input: RDD[LabeledPoint], @@ -271,6 +284,7 @@ object LogisticRegressionWithSGD { * @param stepSize Step size to be used for each iteration of gradient descent. * @param miniBatchFraction Fraction of data to be used per iteration. + * @since 1.0.0 */ def train( input: RDD[LabeledPoint], @@ -292,6 +306,7 @@ object LogisticRegressionWithSGD { * @param numIterations Number of iterations of gradient descent to run. * @return a LogisticRegressionModel which has the weights and offset from training. + * @since 1.0.0 */ def train( input: RDD[LabeledPoint], @@ -309,6 +324,7 @@ object LogisticRegressionWithSGD { * @param input RDD of (label, array of features) pairs. * @param numIterations Number of iterations of gradient descent to run. * @return a LogisticRegressionModel which has the weights and offset from training. + * @since 1.0.0 */ def train( input: RDD[LabeledPoint], @@ -345,6 +361,7 @@ class LogisticRegressionWithLBFGS * Set the number of possible outcomes for k classes classification problem in * Multinomial Logistic Regression. * By default, it is binary logistic regression so k will be set to 2. + * @since 1.3.0 */ @Experimental def setNumClasses(numClasses: Int): this.type = { diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala index 8cf4e15efe7a7..2df91c09421e9 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala @@ -444,6 +444,7 @@ object NaiveBayes { * * @param input RDD of `(label, array of features)` pairs. Every vector should be a frequency * vector or a count vector. + * @since 0.9.0 */ def train(input: RDD[LabeledPoint]): NaiveBayesModel = { new NaiveBayes().run(input) @@ -459,6 +460,7 @@ object NaiveBayes { * @param input RDD of `(label, array of features)` pairs. Every vector should be a frequency * vector or a count vector. * @param lambda The smoothing parameter + * @since 0.9.0 */ def train(input: RDD[LabeledPoint], lambda: Double): NaiveBayesModel = { new NaiveBayes(lambda, Multinomial).run(input) @@ -481,6 +483,7 @@ object NaiveBayes { * * @param modelType The type of NB model to fit from the enumeration NaiveBayesModels, can be * multinomial or bernoulli + * @since 0.9.0 */ def train(input: RDD[LabeledPoint], lambda: Double, modelType: String): NaiveBayesModel = { require(supportedModelTypes.contains(modelType), diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala index 348485560713e..5b54feeb10467 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala @@ -46,6 +46,7 @@ class SVMModel ( * Sets the threshold that separates positive predictions from negative predictions. An example * with prediction score greater than or equal to this threshold is identified as an positive, * and negative otherwise. The default value is 0.0. + * @since 1.3.0 */ @Experimental def setThreshold(threshold: Double): this.type = { @@ -56,6 +57,7 @@ class SVMModel ( /** * :: Experimental :: * Returns the threshold (if any) used for converting raw prediction scores into 0/1 predictions. 
+ * @since 1.3.0 */ @Experimental def getThreshold: Option[Double] = threshold @@ -63,6 +65,7 @@ class SVMModel ( /** * :: Experimental :: * Clears the threshold so that `predict` will output raw prediction scores. + * @since 1.0.0 */ @Experimental def clearThreshold(): this.type = { @@ -81,6 +84,9 @@ class SVMModel ( } } + /** + * @since 1.3.0 + */ override def save(sc: SparkContext, path: String): Unit = { GLMClassificationModel.SaveLoadV1_0.save(sc, path, this.getClass.getName, numFeatures = weights.size, numClasses = 2, weights, intercept, threshold) @@ -88,6 +94,9 @@ class SVMModel ( override protected def formatVersion: String = "1.0" + /** + * @since 1.4.0 + */ override def toString: String = { s"${super.toString}, numClasses = 2, threshold = ${threshold.getOrElse("None")}" } @@ -95,6 +104,9 @@ class SVMModel ( object SVMModel extends Loader[SVMModel] { + /** + * @since 1.3.0 + */ override def load(sc: SparkContext, path: String): SVMModel = { val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path) // Hard-code class name string in case it changes in the future @@ -173,6 +185,7 @@ object SVMWithSGD { * @param miniBatchFraction Fraction of data to be used per iteration. * @param initialWeights Initial set of weights to be used. Array should be equal in size to * the number of features in the data. + * @since 0.8.0 */ def train( input: RDD[LabeledPoint], @@ -196,6 +209,7 @@ object SVMWithSGD { * @param stepSize Step size to be used for each iteration of gradient descent. * @param regParam Regularization parameter. * @param miniBatchFraction Fraction of data to be used per iteration. + * @since 0.8.0 */ def train( input: RDD[LabeledPoint], @@ -217,6 +231,7 @@ object SVMWithSGD { * @param regParam Regularization parameter. * @param numIterations Number of iterations of gradient descent to run. * @return a SVMModel which has the weights and offset from training. + * @since 0.8.0 */ def train( input: RDD[LabeledPoint], @@ -235,6 +250,7 @@ object SVMWithSGD { * @param input RDD of (label, array of features) pairs. * @param numIterations Number of iterations of gradient descent to run. * @return a SVMModel which has the weights and offset from training. + * @since 0.8.0 */ def train(input: RDD[LabeledPoint], numIterations: Int): SVMModel = { train(input, numIterations, 1.0, 0.01, 1.0) From 6592a6058eee6a27a5c91281ca19076284d62483 Mon Sep 17 00:00:00 2001 From: Grace Date: Tue, 21 Jul 2015 11:35:49 -0500 Subject: [PATCH 24/32] [SPARK-9193] Avoid assigning tasks to "lost" executor(s) Now, when some executors are killed by dynamic-allocation, it leads to some mis-assignment onto lost executors sometimes. Such kind of mis-assignment causes task failure(s) or even job failure if it repeats that errors for 4 times. The root cause is that ***killExecutors*** doesn't remove those executors under killing ASAP. It depends on the ***OnDisassociated*** event to refresh the active working list later. The delay time really depends on your cluster status (from several milliseconds to sub-minute). When new tasks to be scheduled during that period of time, it will be assigned to those "active" but "under killing" executors. Then the tasks will be failed due to "executor lost". The better way is to exclude those executors under killing in the makeOffers(). Then all those tasks won't be allocated onto those executors "to be lost" any more. 
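The essence of the fix is to drop executors that are pending removal when building resource offers. A simplified, self-contained sketch of the idea follows; the case classes here are stand-ins for illustration, while the real change is applied to CoarseGrainedSchedulerBackend.makeOffers in the diff below.

  // Stand-in types for illustration; the real ones live in the scheduler package.
  case class ExecutorData(executorHost: String, freeCores: Int)
  case class WorkerOffer(executorId: String, host: String, cores: Int)

  // Only offer resources from executors that are not already being killed.
  def makeOffers(
      executorDataMap: Map[String, ExecutorData],
      executorsPendingToRemove: Set[String]): Seq[WorkerOffer] = {
    executorDataMap
      .filterKeys(id => !executorsPendingToRemove.contains(id))
      .map { case (id, data) => WorkerOffer(id, data.executorHost, data.freeCores) }
      .toSeq
  }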
Author: Grace Closes #7528 from GraceH/AssignToLostExecutor and squashes the following commits: ecc1da6 [Grace] scala style fix 6e2ed96 [Grace] Re-word makeOffers by more readable lines b5546ce [Grace] Add comments about the fix 30a9ad0 [Grace] Avoid assigning tasks to lost executors --- .../cluster/CoarseGrainedSchedulerBackend.scala | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index f14c603ac6891..c65b3e517773e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -169,9 +169,12 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp // Make fake resource offers on all executors private def makeOffers() { - launchTasks(scheduler.resourceOffers(executorDataMap.map { case (id, executorData) => + // Filter out executors under killing + val activeExecutors = executorDataMap.filterKeys(!executorsPendingToRemove.contains(_)) + val workOffers = activeExecutors.map { case (id, executorData) => new WorkerOffer(id, executorData.executorHost, executorData.freeCores) - }.toSeq)) + }.toSeq + launchTasks(scheduler.resourceOffers(workOffers)) } override def onDisconnected(remoteAddress: RpcAddress): Unit = { @@ -181,9 +184,13 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp // Make fake resource offers on just one executor private def makeOffers(executorId: String) { - val executorData = executorDataMap(executorId) - launchTasks(scheduler.resourceOffers( - Seq(new WorkerOffer(executorId, executorData.executorHost, executorData.freeCores)))) + // Filter out executors under killing + if (!executorsPendingToRemove.contains(executorId)) { + val executorData = executorDataMap(executorId) + val workOffers = Seq( + new WorkerOffer(executorId, executorData.executorHost, executorData.freeCores)) + launchTasks(scheduler.resourceOffers(workOffers)) + } } // Launch tasks returned by a set of resource offers From f67da43c394c27ceb4e6bfd49e81be05e406aa29 Mon Sep 17 00:00:00 2001 From: Ben Date: Tue, 21 Jul 2015 09:51:13 -0700 Subject: [PATCH 25/32] [SPARK-9036] [CORE] SparkListenerExecutorMetricsUpdate messages not included in JsonProtocol This PR implements a JSON serializer and deserializer in the JSONProtocol to handle the (de)serialization of SparkListenerExecutorMetricsUpdate events. It also includes a unit test in the JSONProtocolSuite file. This was implemented to satisfy the improvement request in the JIRA issue SPARK-9036. 
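For orientation, a minimal round trip through the two new methods might look like the sketch below. This uses Spark-internal API; construction of the TaskMetrics payload is elided, and the tuple layout follows the event structure used by the new serializer.

  import org.apache.spark.executor.TaskMetrics
  import org.apache.spark.scheduler.SparkListenerExecutorMetricsUpdate
  import org.apache.spark.util.JsonProtocol

  // Sketch: serialize an executor metrics update event and read it back.
  // `metrics` is assumed to be an already-populated TaskMetrics instance.
  def roundTrip(metrics: TaskMetrics): Unit = {
    // payload entries are (taskId, stageId, stageAttemptId, metrics) tuples
    val event = SparkListenerExecutorMetricsUpdate("exec3", Seq((1L, 2, 3, metrics)))
    val json = JsonProtocol.executorMetricsUpdateToJson(event)
    val restored = JsonProtocol.executorMetricsUpdateFromJson(json)
    assert(restored.execId == event.execId)
  }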
Author: Ben Closes #7555 from NamelessAnalyst/master and squashes the following commits: fb4e3cc [Ben] Update JSON Protocol and tests aa69517 [Ben] Update JSON Protocol and tests --Corrected Stage Attempt to Stage Attempt ID 33e5774 [Ben] Update JSON Protocol Tests 3f237e7 [Ben] Update JSON Protocol Tests 84ca798 [Ben] Update JSON Protocol Tests cde57a0 [Ben] Update JSON Protocol Tests 8049600 [Ben] Update JSON Protocol Tests c5bc061 [Ben] Update JSON Protocol Tests 6f25785 [Ben] Merge remote-tracking branch 'origin/master' df2a609 [Ben] Update JSON Protocol dcda80b [Ben] Update JSON Protocol --- .../org/apache/spark/util/JsonProtocol.scala | 31 ++++++++- .../apache/spark/util/JsonProtocolSuite.scala | 69 ++++++++++++++++++- 2 files changed, 96 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index adf69a4e78e71..a078f14af52a1 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -92,8 +92,8 @@ private[spark] object JsonProtocol { executorRemovedToJson(executorRemoved) case logStart: SparkListenerLogStart => logStartToJson(logStart) - // These aren't used, but keeps compiler happy - case SparkListenerExecutorMetricsUpdate(_, _) => JNothing + case metricsUpdate: SparkListenerExecutorMetricsUpdate => + executorMetricsUpdateToJson(metricsUpdate) } } @@ -224,6 +224,19 @@ private[spark] object JsonProtocol { ("Spark Version" -> SPARK_VERSION) } + def executorMetricsUpdateToJson(metricsUpdate: SparkListenerExecutorMetricsUpdate): JValue = { + val execId = metricsUpdate.execId + val taskMetrics = metricsUpdate.taskMetrics + ("Event" -> Utils.getFormattedClassName(metricsUpdate)) ~ + ("Executor ID" -> execId) ~ + ("Metrics Updated" -> taskMetrics.map { case (taskId, stageId, stageAttemptId, metrics) => + ("Task ID" -> taskId) ~ + ("Stage ID" -> stageId) ~ + ("Stage Attempt ID" -> stageAttemptId) ~ + ("Task Metrics" -> taskMetricsToJson(metrics)) + }) + } + /** ------------------------------------------------------------------- * * JSON serialization methods for classes SparkListenerEvents depend on | * -------------------------------------------------------------------- */ @@ -463,6 +476,7 @@ private[spark] object JsonProtocol { val executorAdded = Utils.getFormattedClassName(SparkListenerExecutorAdded) val executorRemoved = Utils.getFormattedClassName(SparkListenerExecutorRemoved) val logStart = Utils.getFormattedClassName(SparkListenerLogStart) + val metricsUpdate = Utils.getFormattedClassName(SparkListenerExecutorMetricsUpdate) (json \ "Event").extract[String] match { case `stageSubmitted` => stageSubmittedFromJson(json) @@ -481,6 +495,7 @@ private[spark] object JsonProtocol { case `executorAdded` => executorAddedFromJson(json) case `executorRemoved` => executorRemovedFromJson(json) case `logStart` => logStartFromJson(json) + case `metricsUpdate` => executorMetricsUpdateFromJson(json) } } @@ -598,6 +613,18 @@ private[spark] object JsonProtocol { SparkListenerLogStart(sparkVersion) } + def executorMetricsUpdateFromJson(json: JValue): SparkListenerExecutorMetricsUpdate = { + val execInfo = (json \ "Executor ID").extract[String] + val taskMetrics = (json \ "Metrics Updated").extract[List[JValue]].map { json => + val taskId = (json \ "Task ID").extract[Long] + val stageId = (json \ "Stage ID").extract[Int] + val stageAttemptId = (json \ "Stage Attempt ID").extract[Int] + val metrics = 
taskMetricsFromJson(json \ "Task Metrics") + (taskId, stageId, stageAttemptId, metrics) + } + SparkListenerExecutorMetricsUpdate(execInfo, taskMetrics) + } + /** --------------------------------------------------------------------- * * JSON deserialization methods for classes SparkListenerEvents depend on | * ---------------------------------------------------------------------- */ diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index e0ef9c70a5fc3..dde95f3778434 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -83,6 +83,9 @@ class JsonProtocolSuite extends SparkFunSuite { val executorAdded = SparkListenerExecutorAdded(executorAddedTime, "exec1", new ExecutorInfo("Hostee.awesome.com", 11, logUrlMap)) val executorRemoved = SparkListenerExecutorRemoved(executorRemovedTime, "exec2", "test reason") + val executorMetricsUpdate = SparkListenerExecutorMetricsUpdate("exec3", Seq( + (1L, 2, 3, makeTaskMetrics(300L, 400L, 500L, 600L, 700, 800, + hasHadoopInput = true, hasOutput = true)))) testEvent(stageSubmitted, stageSubmittedJsonString) testEvent(stageCompleted, stageCompletedJsonString) @@ -102,6 +105,7 @@ class JsonProtocolSuite extends SparkFunSuite { testEvent(applicationEnd, applicationEndJsonString) testEvent(executorAdded, executorAddedJsonString) testEvent(executorRemoved, executorRemovedJsonString) + testEvent(executorMetricsUpdate, executorMetricsUpdateJsonString) } test("Dependent Classes") { @@ -440,10 +444,20 @@ class JsonProtocolSuite extends SparkFunSuite { case (e1: SparkListenerEnvironmentUpdate, e2: SparkListenerEnvironmentUpdate) => assertEquals(e1.environmentDetails, e2.environmentDetails) case (e1: SparkListenerExecutorAdded, e2: SparkListenerExecutorAdded) => - assert(e1.executorId == e1.executorId) + assert(e1.executorId === e1.executorId) assertEquals(e1.executorInfo, e2.executorInfo) case (e1: SparkListenerExecutorRemoved, e2: SparkListenerExecutorRemoved) => - assert(e1.executorId == e1.executorId) + assert(e1.executorId === e1.executorId) + case (e1: SparkListenerExecutorMetricsUpdate, e2: SparkListenerExecutorMetricsUpdate) => + assert(e1.execId === e2.execId) + assertSeqEquals[(Long, Int, Int, TaskMetrics)](e1.taskMetrics, e2.taskMetrics, (a, b) => { + val (taskId1, stageId1, stageAttemptId1, metrics1) = a + val (taskId2, stageId2, stageAttemptId2, metrics2) = b + assert(taskId1 === taskId2) + assert(stageId1 === stageId2) + assert(stageAttemptId1 === stageAttemptId2) + assertEquals(metrics1, metrics2) + }) case (e1, e2) => assert(e1 === e2) case _ => fail("Events don't match in types!") @@ -1598,4 +1612,55 @@ class JsonProtocolSuite extends SparkFunSuite { | "Removed Reason": "test reason" |} """ + + private val executorMetricsUpdateJsonString = + s""" + |{ + | "Event": "SparkListenerExecutorMetricsUpdate", + | "Executor ID": "exec3", + | "Metrics Updated": [ + | { + | "Task ID": 1, + | "Stage ID": 2, + | "Stage Attempt ID": 3, + | "Task Metrics": { + | "Host Name": "localhost", + | "Executor Deserialize Time": 300, + | "Executor Run Time": 400, + | "Result Size": 500, + | "JVM GC Time": 600, + | "Result Serialization Time": 700, + | "Memory Bytes Spilled": 800, + | "Disk Bytes Spilled": 0, + | "Input Metrics": { + | "Data Read Method": "Hadoop", + | "Bytes Read": 2100, + | "Records Read": 21 + | }, + | "Output Metrics": { + | "Data Write Method": "Hadoop", + | 
"Bytes Written": 1200, + | "Records Written": 12 + | }, + | "Updated Blocks": [ + | { + | "Block ID": "rdd_0_0", + | "Status": { + | "Storage Level": { + | "Use Disk": true, + | "Use Memory": true, + | "Use ExternalBlockStore": false, + | "Deserialized": false, + | "Replication": 2 + | }, + | "Memory Size": 0, + | "ExternalBlockStore Size": 0, + | "Disk Size": 0 + | } + | } + | ] + | } + | }] + |} + """.stripMargin } From 9a4fd875b39b6a1ef7038823d1c49b0826110fbc Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 21 Jul 2015 09:52:27 -0700 Subject: [PATCH 26/32] [SPARK-9128] [CORE] Get outerclasses and objects with only one method calling in ClosureCleaner JIRA: https://issues.apache.org/jira/browse/SPARK-9128 Currently, in `ClosureCleaner`, the outerclasses and objects are retrieved using two different methods. However, the logic of the two methods is the same, and we can get both the outerclasses and objects with only one method calling. Author: Liang-Chi Hsieh Closes #7459 from viirya/remove_extra_closurecleaner and squashes the following commits: 7c9858d [Liang-Chi Hsieh] For comments. a096941 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into remove_extra_closurecleaner 2ec5ce1 [Liang-Chi Hsieh] Remove unnecessary methods. 4df5a51 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into remove_extra_closurecleaner dc110d1 [Liang-Chi Hsieh] Add method to get outerclasses and objects at the same time. --- .../apache/spark/util/ClosureCleaner.scala | 32 +++-------- .../spark/util/ClosureCleanerSuite2.scala | 54 ++++++++----------- 2 files changed, 28 insertions(+), 58 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala index 43626b4ef4880..ebead830c6466 100644 --- a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala +++ b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala @@ -49,45 +49,28 @@ private[spark] object ClosureCleaner extends Logging { cls.getName.contains("$anonfun$") } - // Get a list of the classes of the outer objects of a given closure object, obj; + // Get a list of the outer objects and their classes of a given closure object, obj; // the outer objects are defined as any closures that obj is nested within, plus // possibly the class that the outermost closure is in, if any. We stop searching // for outer objects beyond that because cloning the user's object is probably // not a good idea (whereas we can clone closure objects just fine since we // understand how all their fields are used). - private def getOuterClasses(obj: AnyRef): List[Class[_]] = { + private def getOuterClassesAndObjects(obj: AnyRef): (List[Class[_]], List[AnyRef]) = { for (f <- obj.getClass.getDeclaredFields if f.getName == "$outer") { f.setAccessible(true) val outer = f.get(obj) // The outer pointer may be null if we have cleaned this closure before if (outer != null) { if (isClosure(f.getType)) { - return f.getType :: getOuterClasses(outer) + val recurRet = getOuterClassesAndObjects(outer) + return (f.getType :: recurRet._1, outer :: recurRet._2) } else { - return f.getType :: Nil // Stop at the first $outer that is not a closure + return (f.getType :: Nil, outer :: Nil) // Stop at the first $outer that is not a closure } } } - Nil + (Nil, Nil) } - - // Get a list of the outer objects for a given closure object. 
- private def getOuterObjects(obj: AnyRef): List[AnyRef] = { - for (f <- obj.getClass.getDeclaredFields if f.getName == "$outer") { - f.setAccessible(true) - val outer = f.get(obj) - // The outer pointer may be null if we have cleaned this closure before - if (outer != null) { - if (isClosure(f.getType)) { - return outer :: getOuterObjects(outer) - } else { - return outer :: Nil // Stop at the first $outer that is not a closure - } - } - } - Nil - } - /** * Return a list of classes that represent closures enclosed in the given closure object. */ @@ -205,8 +188,7 @@ private[spark] object ClosureCleaner extends Logging { // A list of enclosing objects and their respective classes, from innermost to outermost // An outer object at a given index is of type outer class at the same index - val outerClasses = getOuterClasses(func) - val outerObjects = getOuterObjects(func) + val (outerClasses, outerObjects) = getOuterClassesAndObjects(func) // For logging purposes only val declaredFields = func.getClass.getDeclaredFields diff --git a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite2.scala b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite2.scala index 3147c937769d2..a829b099025e9 100644 --- a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite2.scala +++ b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite2.scala @@ -120,8 +120,8 @@ class ClosureCleanerSuite2 extends SparkFunSuite with BeforeAndAfterAll with Pri // Accessors for private methods private val _isClosure = PrivateMethod[Boolean]('isClosure) private val _getInnerClosureClasses = PrivateMethod[List[Class[_]]]('getInnerClosureClasses) - private val _getOuterClasses = PrivateMethod[List[Class[_]]]('getOuterClasses) - private val _getOuterObjects = PrivateMethod[List[AnyRef]]('getOuterObjects) + private val _getOuterClassesAndObjects = + PrivateMethod[(List[Class[_]], List[AnyRef])]('getOuterClassesAndObjects) private def isClosure(obj: AnyRef): Boolean = { ClosureCleaner invokePrivate _isClosure(obj) @@ -131,12 +131,8 @@ class ClosureCleanerSuite2 extends SparkFunSuite with BeforeAndAfterAll with Pri ClosureCleaner invokePrivate _getInnerClosureClasses(closure) } - private def getOuterClasses(closure: AnyRef): List[Class[_]] = { - ClosureCleaner invokePrivate _getOuterClasses(closure) - } - - private def getOuterObjects(closure: AnyRef): List[AnyRef] = { - ClosureCleaner invokePrivate _getOuterObjects(closure) + private def getOuterClassesAndObjects(closure: AnyRef): (List[Class[_]], List[AnyRef]) = { + ClosureCleaner invokePrivate _getOuterClassesAndObjects(closure) } test("get inner closure classes") { @@ -171,14 +167,11 @@ class ClosureCleanerSuite2 extends SparkFunSuite with BeforeAndAfterAll with Pri val closure2 = () => localValue val closure3 = () => someSerializableValue val closure4 = () => someSerializableMethod() - val outerClasses1 = getOuterClasses(closure1) - val outerClasses2 = getOuterClasses(closure2) - val outerClasses3 = getOuterClasses(closure3) - val outerClasses4 = getOuterClasses(closure4) - val outerObjects1 = getOuterObjects(closure1) - val outerObjects2 = getOuterObjects(closure2) - val outerObjects3 = getOuterObjects(closure3) - val outerObjects4 = getOuterObjects(closure4) + + val (outerClasses1, outerObjects1) = getOuterClassesAndObjects(closure1) + val (outerClasses2, outerObjects2) = getOuterClassesAndObjects(closure2) + val (outerClasses3, outerObjects3) = getOuterClassesAndObjects(closure3) + val (outerClasses4, outerObjects4) = 
getOuterClassesAndObjects(closure4) // The classes and objects should have the same size assert(outerClasses1.size === outerObjects1.size) @@ -211,10 +204,8 @@ class ClosureCleanerSuite2 extends SparkFunSuite with BeforeAndAfterAll with Pri val x = 1 val closure1 = () => 1 val closure2 = () => x - val outerClasses1 = getOuterClasses(closure1) - val outerClasses2 = getOuterClasses(closure2) - val outerObjects1 = getOuterObjects(closure1) - val outerObjects2 = getOuterObjects(closure2) + val (outerClasses1, outerObjects1) = getOuterClassesAndObjects(closure1) + val (outerClasses2, outerObjects2) = getOuterClassesAndObjects(closure2) assert(outerClasses1.size === outerObjects1.size) assert(outerClasses2.size === outerObjects2.size) // These inner closures only reference local variables, and so do not have $outer pointers @@ -227,12 +218,9 @@ class ClosureCleanerSuite2 extends SparkFunSuite with BeforeAndAfterAll with Pri val closure1 = () => 1 val closure2 = () => y val closure3 = () => localValue - val outerClasses1 = getOuterClasses(closure1) - val outerClasses2 = getOuterClasses(closure2) - val outerClasses3 = getOuterClasses(closure3) - val outerObjects1 = getOuterObjects(closure1) - val outerObjects2 = getOuterObjects(closure2) - val outerObjects3 = getOuterObjects(closure3) + val (outerClasses1, outerObjects1) = getOuterClassesAndObjects(closure1) + val (outerClasses2, outerObjects2) = getOuterClassesAndObjects(closure2) + val (outerClasses3, outerObjects3) = getOuterClassesAndObjects(closure3) assert(outerClasses1.size === outerObjects1.size) assert(outerClasses2.size === outerObjects2.size) assert(outerClasses3.size === outerObjects3.size) @@ -265,9 +253,9 @@ class ClosureCleanerSuite2 extends SparkFunSuite with BeforeAndAfterAll with Pri val closure1 = () => 1 val closure2 = () => localValue val closure3 = () => someSerializableValue - val outerClasses1 = getOuterClasses(closure1) - val outerClasses2 = getOuterClasses(closure2) - val outerClasses3 = getOuterClasses(closure3) + val (outerClasses1, _) = getOuterClassesAndObjects(closure1) + val (outerClasses2, _) = getOuterClassesAndObjects(closure2) + val (outerClasses3, _) = getOuterClassesAndObjects(closure3) val fields1 = findAccessedFields(closure1, outerClasses1, findTransitively = false) val fields2 = findAccessedFields(closure2, outerClasses2, findTransitively = false) @@ -307,10 +295,10 @@ class ClosureCleanerSuite2 extends SparkFunSuite with BeforeAndAfterAll with Pri val closure2 = () => a val closure3 = () => localValue val closure4 = () => someSerializableValue - val outerClasses1 = getOuterClasses(closure1) - val outerClasses2 = getOuterClasses(closure2) - val outerClasses3 = getOuterClasses(closure3) - val outerClasses4 = getOuterClasses(closure4) + val (outerClasses1, _) = getOuterClassesAndObjects(closure1) + val (outerClasses2, _) = getOuterClassesAndObjects(closure2) + val (outerClasses3, _) = getOuterClassesAndObjects(closure3) + val (outerClasses4, _) = getOuterClassesAndObjects(closure4) // First, find only fields accessed directly, not transitively, by these closures val fields1 = findAccessedFields(closure1, outerClasses1, findTransitively = false) From 31954910d67c29874d2af22ee30590a7346a464c Mon Sep 17 00:00:00 2001 From: Jacek Lewandowski Date: Tue, 21 Jul 2015 09:53:33 -0700 Subject: [PATCH 27/32] [SPARK-7171] Added a method to retrieve metrics sources in TaskContext Author: Jacek Lewandowski Closes #5805 from jacek-lewandowski/SPARK-7171 and squashes the following commits: ed20bda [Jacek Lewandowski] 
SPARK-7171: Added a method to retrieve metrics sources in TaskContext --- .../scala/org/apache/spark/TaskContext.scala | 9 ++++++++ .../org/apache/spark/TaskContextImpl.scala | 6 +++++ .../org/apache/spark/executor/Executor.scala | 5 ++++- .../apache/spark/metrics/MetricsSystem.scala | 3 +++ .../apache/spark/scheduler/DAGScheduler.scala | 1 + .../org/apache/spark/scheduler/Task.scala | 8 ++++++- .../java/org/apache/spark/JavaAPISuite.java | 2 +- .../org/apache/spark/CacheManagerSuite.scala | 8 +++---- .../org/apache/spark/rdd/PipedRDDSuite.scala | 2 +- .../spark/scheduler/TaskContextSuite.scala | 22 ++++++++++++++++--- .../shuffle/hash/HashShuffleReaderSuite.scala | 2 +- .../ShuffleBlockFetcherIteratorSuite.scala | 6 ++--- 12 files changed, 59 insertions(+), 15 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala index e93eb93124e51..b48836d5c8897 100644 --- a/core/src/main/scala/org/apache/spark/TaskContext.scala +++ b/core/src/main/scala/org/apache/spark/TaskContext.scala @@ -21,6 +21,7 @@ import java.io.Serializable import org.apache.spark.annotation.DeveloperApi import org.apache.spark.executor.TaskMetrics +import org.apache.spark.metrics.source.Source import org.apache.spark.unsafe.memory.TaskMemoryManager import org.apache.spark.util.TaskCompletionListener @@ -148,6 +149,14 @@ abstract class TaskContext extends Serializable { @DeveloperApi def taskMetrics(): TaskMetrics + /** + * ::DeveloperApi:: + * Returns all metrics sources with the given name which are associated with the instance + * which runs the task. For more information see [[org.apache.spark.metrics.MetricsSystem!]]. + */ + @DeveloperApi + def getMetricsSources(sourceName: String): Seq[Source] + /** * Returns the manager for this task's managed memory. 
*/ diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala index 6e394f1b12445..9ee168ae016f8 100644 --- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala +++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala @@ -20,6 +20,8 @@ package org.apache.spark import scala.collection.mutable.{ArrayBuffer, HashMap} import org.apache.spark.executor.TaskMetrics +import org.apache.spark.metrics.MetricsSystem +import org.apache.spark.metrics.source.Source import org.apache.spark.unsafe.memory.TaskMemoryManager import org.apache.spark.util.{TaskCompletionListener, TaskCompletionListenerException} @@ -29,6 +31,7 @@ private[spark] class TaskContextImpl( override val taskAttemptId: Long, override val attemptNumber: Int, override val taskMemoryManager: TaskMemoryManager, + @transient private val metricsSystem: MetricsSystem, val runningLocally: Boolean = false, val taskMetrics: TaskMetrics = TaskMetrics.empty) extends TaskContext @@ -95,6 +98,9 @@ private[spark] class TaskContextImpl( override def isInterrupted(): Boolean = interrupted + override def getMetricsSources(sourceName: String): Seq[Source] = + metricsSystem.getSourcesByName(sourceName) + @transient private val accumulators = new HashMap[Long, Accumulable[_, _]] private[spark] override def registerAccumulator(a: Accumulable[_, _]): Unit = synchronized { diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 9087debde8c41..66624ffbe4790 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -210,7 +210,10 @@ private[spark] class Executor( // Run the actual task and measure its runtime. taskStart = System.currentTimeMillis() val (value, accumUpdates) = try { - task.run(taskAttemptId = taskId, attemptNumber = attemptNumber) + task.run( + taskAttemptId = taskId, + attemptNumber = attemptNumber, + metricsSystem = env.metricsSystem) } finally { // Note: this memory freeing logic is duplicated in DAGScheduler.runLocallyWithinThread; // when changing this, make sure to update both copies. 
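For context: with env.metricsSystem now threaded from the executor into TaskContextImpl, code running inside a task can look up registered metrics sources by name. A minimal, illustrative sketch of the new getMetricsSources API (not part of the patch; it assumes a running SparkContext named sc and a metrics configuration that registers the standard "jvm" source on executors, as the TaskContextSuite test below does via test_metrics_config.properties):

    import org.apache.spark.TaskContext
    import org.apache.spark.metrics.source.Source

    // For each partition, count how many sources named "jvm" the running task can see.
    val perPartitionCounts: Array[Int] =
      sc.runJob(sc.parallelize(1 to 4, numSlices = 2), (tc: TaskContext, it: Iterator[Int]) => {
        val sources: Seq[Source] = tc.getMetricsSources("jvm")
        sources.size
      })
    // With JvmSource registered, every element of perPartitionCounts should be at least 1.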
diff --git a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala index 67f64d5e278de..4517f465ebd3b 100644 --- a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala +++ b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala @@ -142,6 +142,9 @@ private[spark] class MetricsSystem private ( } else { defaultName } } + def getSourcesByName(sourceName: String): Seq[Source] = + sources.filter(_.sourceName == sourceName) + def registerSource(source: Source) { sources += source try { diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 71a219a4f3414..b829d06923404 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -682,6 +682,7 @@ class DAGScheduler( taskAttemptId = 0, attemptNumber = 0, taskMemoryManager = taskMemoryManager, + metricsSystem = env.metricsSystem, runningLocally = true) TaskContext.setTaskContext(taskContext) try { diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 76a19aeac4679..d11a00956a9a9 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -22,6 +22,7 @@ import java.nio.ByteBuffer import scala.collection.mutable.HashMap +import org.apache.spark.metrics.MetricsSystem import org.apache.spark.{TaskContextImpl, TaskContext} import org.apache.spark.executor.TaskMetrics import org.apache.spark.serializer.SerializerInstance @@ -61,13 +62,18 @@ private[spark] abstract class Task[T]( * @param attemptNumber how many times this task has been attempted (0 for the first attempt) * @return the result of the task along with updates of Accumulators. 
*/ - final def run(taskAttemptId: Long, attemptNumber: Int): (T, AccumulatorUpdates) = { + final def run( + taskAttemptId: Long, + attemptNumber: Int, + metricsSystem: MetricsSystem) + : (T, AccumulatorUpdates) = { context = new TaskContextImpl( stageId = stageId, partitionId = partitionId, taskAttemptId = taskAttemptId, attemptNumber = attemptNumber, taskMemoryManager = taskMemoryManager, + metricsSystem = metricsSystem, runningLocally = false) TaskContext.setTaskContext(context) context.taskMetrics.setHostname(Utils.localHostName()) diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java index dfd86d3e51e7d..1b04a3b1cff0e 100644 --- a/core/src/test/java/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java @@ -1011,7 +1011,7 @@ public void persist() { @Test public void iterator() { JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 2); - TaskContext context = new TaskContextImpl(0, 0, 0L, 0, null, false, new TaskMetrics()); + TaskContext context = new TaskContextImpl(0, 0, 0L, 0, null, null, false, new TaskMetrics()); Assert.assertEquals(1, rdd.iterator(rdd.partitions().get(0), context).next().intValue()); } diff --git a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala index af81e46a657d3..618a5fb24710f 100644 --- a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala @@ -65,7 +65,7 @@ class CacheManagerSuite extends SparkFunSuite with LocalSparkContext with Before // in blockManager.put is a losing battle. You have been warned. blockManager = sc.env.blockManager cacheManager = sc.env.cacheManager - val context = new TaskContextImpl(0, 0, 0, 0, null) + val context = new TaskContextImpl(0, 0, 0, 0, null, null) val computeValue = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY) val getValue = blockManager.get(RDDBlockId(rdd.id, split.index)) assert(computeValue.toList === List(1, 2, 3, 4)) @@ -77,7 +77,7 @@ class CacheManagerSuite extends SparkFunSuite with LocalSparkContext with Before val result = new BlockResult(Array(5, 6, 7).iterator, DataReadMethod.Memory, 12) when(blockManager.get(RDDBlockId(0, 0))).thenReturn(Some(result)) - val context = new TaskContextImpl(0, 0, 0, 0, null) + val context = new TaskContextImpl(0, 0, 0, 0, null, null) val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY) assert(value.toList === List(5, 6, 7)) } @@ -86,14 +86,14 @@ class CacheManagerSuite extends SparkFunSuite with LocalSparkContext with Before // Local computation should not persist the resulting value, so don't expect a put(). 
when(blockManager.get(RDDBlockId(0, 0))).thenReturn(None) - val context = new TaskContextImpl(0, 0, 0, 0, null, true) + val context = new TaskContextImpl(0, 0, 0, 0, null, null, true) val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY) assert(value.toList === List(1, 2, 3, 4)) } test("verify task metrics updated correctly") { cacheManager = sc.env.cacheManager - val context = new TaskContextImpl(0, 0, 0, 0, null) + val context = new TaskContextImpl(0, 0, 0, 0, null, null) cacheManager.getOrCompute(rdd3, split, context, StorageLevel.MEMORY_ONLY) assert(context.taskMetrics.updatedBlocks.getOrElse(Seq()).size === 2) } diff --git a/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala index 32f04d54eff94..3e8816a4c65be 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala @@ -175,7 +175,7 @@ class PipedRDDSuite extends SparkFunSuite with SharedSparkContext { } val hadoopPart1 = generateFakeHadoopPartition() val pipedRdd = new PipedRDD(nums, "printenv " + varName) - val tContext = new TaskContextImpl(0, 0, 0, 0, null) + val tContext = new TaskContextImpl(0, 0, 0, 0, null, null) val rddIter = pipedRdd.compute(hadoopPart1, tContext) val arr = rddIter.toArray assert(arr(0) == "/some/path") diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala index b9b0eccb0d834..9201d1e1f328b 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala @@ -24,11 +24,27 @@ import org.scalatest.BeforeAndAfter import org.apache.spark._ import org.apache.spark.rdd.RDD -import org.apache.spark.util.{TaskCompletionListenerException, TaskCompletionListener} +import org.apache.spark.util.{TaskCompletionListener, TaskCompletionListenerException} +import org.apache.spark.metrics.source.JvmSource class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSparkContext { + test("provide metrics sources") { + val filePath = getClass.getClassLoader.getResource("test_metrics_config.properties").getFile + val conf = new SparkConf(loadDefaults = false) + .set("spark.metrics.conf", filePath) + sc = new SparkContext("local", "test", conf) + val rdd = sc.makeRDD(1 to 1) + val result = sc.runJob(rdd, (tc: TaskContext, it: Iterator[Int]) => { + tc.getMetricsSources("jvm").count { + case source: JvmSource => true + case _ => false + } + }).sum + assert(result > 0) + } + test("calls TaskCompletionListener after failure") { TaskContextSuite.completed = false sc = new SparkContext("local", "test") @@ -44,13 +60,13 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark val task = new ResultTask[String, String](0, 0, sc.broadcast(closureSerializer.serialize((rdd, func)).array), rdd.partitions(0), Seq(), 0) intercept[RuntimeException] { - task.run(0, 0) + task.run(0, 0, null) } assert(TaskContextSuite.completed === true) } test("all TaskCompletionListeners should be called even if some fail") { - val context = new TaskContextImpl(0, 0, 0, 0, null) + val context = new TaskContextImpl(0, 0, 0, 0, null, null) val listener = mock(classOf[TaskCompletionListener]) context.addTaskCompletionListener(_ => throw new Exception("blah")) context.addTaskCompletionListener(listener) diff --git 
a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala index 6c9cb448e7833..db718ecabbdb9 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala @@ -138,7 +138,7 @@ class HashShuffleReaderSuite extends SparkFunSuite with LocalSparkContext { shuffleHandle, reduceId, reduceId + 1, - new TaskContextImpl(0, 0, 0, 0, null), + new TaskContextImpl(0, 0, 0, 0, null, null), blockManager, mapOutputTracker) diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 64f3fbdcebed9..cf8bd8ae69625 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -95,7 +95,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT ) val iterator = new ShuffleBlockFetcherIterator( - new TaskContextImpl(0, 0, 0, 0, null), + new TaskContextImpl(0, 0, 0, 0, null, null), transfer, blockManager, blocksByAddress, @@ -165,7 +165,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)) - val taskContext = new TaskContextImpl(0, 0, 0, 0, null) + val taskContext = new TaskContextImpl(0, 0, 0, 0, null, null) val iterator = new ShuffleBlockFetcherIterator( taskContext, transfer, @@ -227,7 +227,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)) - val taskContext = new TaskContextImpl(0, 0, 0, 0, null) + val taskContext = new TaskContextImpl(0, 0, 0, 0, null, null) val iterator = new ShuffleBlockFetcherIterator( taskContext, transfer, From 4f7f1ee378e80b33686508d56e133fc25dec5316 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Tue, 21 Jul 2015 09:54:39 -0700 Subject: [PATCH 28/32] [SPARK-4598] [WEBUI] Task table pagination for the Stage page This PR adds pagination for the task table to solve the scalability issue of the stage page. Here is the initial screenshot: pagination The task table only shows 100 tasks. There is a page navigation above the table. Users can click the page navigation or type the page number to jump to another page. The table can be sorted by clicking the headers. However, unlike previous implementation, the sorting work is done in the server now. So clicking a table column to sort needs to refresh the web page. 
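To make the server-side paging concrete: page p with page size n is simply the slice of the pre-sorted rows from (p - 1) * n (inclusive) to min(dataSize, p * n) (exclusive). A rough sketch against the PagedDataSource abstraction introduced below; the StringDataSource class is invented for this illustration and is not part of the patch:

    // Hypothetical paged source over an in-memory Seq, for illustration only.
    class StringDataSource(strings: Seq[String], pageSize: Int)
      extends PagedDataSource[String](pageSize) {

      override protected def dataSize: Int = strings.size

      // pageData(page) calls this with from = (page - 1) * pageSize (inclusive)
      // and to = min(dataSize, page * pageSize) (exclusive).
      override protected def sliceData(from: Int, to: Int): Seq[String] = strings.slice(from, to)
    }

    val source = new StringDataSource((1 to 8).map(i => s"row-$i"), pageSize = 3)
    val PageData(totalPages, rows) = source.pageData(2)
    // totalPages == 3, rows == Seq("row-4", "row-5", "row-6")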
Author: zsxwing Closes #7399 from zsxwing/task-table-pagination and squashes the following commits: 144f513 [zsxwing] Display the page navigation when the page number is out of range a3eee22 [zsxwing] Add extra space for the error message 54c5b84 [zsxwing] Reset page to 1 if the user changes the page size c2f7f39 [zsxwing] Add a text field to let users fill the page size bad52eb [zsxwing] Display user-friendly error messages 410586b [zsxwing] Scroll down to the tasks table if the url contains any sort column a0746d1 [zsxwing] Use expand-dag-viz-arrow-job and expand-dag-viz-arrow-stage instead of expand-dag-viz-arrow-true and expand-dag-viz-arrow-false b123f67 [zsxwing] Use localStorage to remember the user's actions and replay them when loading the page 894a342 [zsxwing] Show the link cursor when hovering for headers and page links and other minor fix 4d4fecf [zsxwing] Address Carson's comments d9285f0 [zsxwing] Add comments and fix the style 74285fa [zsxwing] Merge branch 'master' into task-table-pagination db6c859 [zsxwing] Task table pagination for the Stage page --- .../spark/ui/static/additional-metrics.js | 34 +- .../apache/spark/ui/static/spark-dag-viz.js | 27 + .../apache/spark/ui/static/timeline-view.js | 39 + .../org/apache/spark/ui/PagedTable.scala | 246 +++++ .../org/apache/spark/ui/jobs/StagePage.scala | 879 +++++++++++++----- .../org/apache/spark/ui/PagedTableSuite.scala | 99 ++ 6 files changed, 1102 insertions(+), 222 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/ui/PagedTable.scala create mode 100644 core/src/test/scala/org/apache/spark/ui/PagedTableSuite.scala diff --git a/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js b/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js index 0b450dc76bc38..3c8ddddf07b1e 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js +++ b/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js @@ -19,6 +19,9 @@ * to be registered after the page loads. */ $(function() { $("span.expand-additional-metrics").click(function(){ + var status = window.localStorage.getItem("expand-additional-metrics") == "true"; + status = !status; + // Expand the list of additional metrics. var additionalMetricsDiv = $(this).parent().find('.additional-metrics'); $(additionalMetricsDiv).toggleClass('collapsed'); @@ -26,17 +29,31 @@ $(function() { // Switch the class of the arrow from open to closed. $(this).find('.expand-additional-metrics-arrow').toggleClass('arrow-open'); $(this).find('.expand-additional-metrics-arrow').toggleClass('arrow-closed'); + + window.localStorage.setItem("expand-additional-metrics", "" + status); }); + if (window.localStorage.getItem("expand-additional-metrics") == "true") { + // Set it to false so that the click function can revert it + window.localStorage.setItem("expand-additional-metrics", "false"); + $("span.expand-additional-metrics").trigger("click"); + } + stripeSummaryTable(); $('input[type="checkbox"]').click(function() { - var column = "table ." + $(this).attr("name"); + var name = $(this).attr("name") + var column = "table ." + name; + var status = window.localStorage.getItem(name) == "true"; + status = !status; $(column).toggle(); stripeSummaryTable(); + window.localStorage.setItem(name, "" + status); }); $("#select-all-metrics").click(function() { + var status = window.localStorage.getItem("select-all-metrics") == "true"; + status = !status; if (this.checked) { // Toggle all un-checked options. 
$('input[type="checkbox"]:not(:checked)').trigger('click'); @@ -44,6 +61,21 @@ $(function() { // Toggle all checked options. $('input[type="checkbox"]:checked').trigger('click'); } + window.localStorage.setItem("select-all-metrics", "" + status); + }); + + if (window.localStorage.getItem("select-all-metrics") == "true") { + $("#select-all-metrics").attr('checked', status); + } + + $("span.additional-metric-title").parent().find('input[type="checkbox"]').each(function() { + var name = $(this).attr("name") + // If name is undefined, then skip it because it's the "select-all-metrics" checkbox + if (name && window.localStorage.getItem(name) == "true") { + // Set it to false so that the click function can revert it + window.localStorage.setItem(name, "false"); + $(this).trigger("click") + } }); // Trigger a click on the checkbox if a user clicks the label next to it. diff --git a/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js b/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js index 9fa53baaf4212..4a893bc0189aa 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js +++ b/core/src/main/resources/org/apache/spark/ui/static/spark-dag-viz.js @@ -72,6 +72,14 @@ var StagePageVizConstants = { rankSep: 40 }; +/* + * Return "expand-dag-viz-arrow-job" if forJob is true. + * Otherwise, return "expand-dag-viz-arrow-stage". + */ +function expandDagVizArrowKey(forJob) { + return forJob ? "expand-dag-viz-arrow-job" : "expand-dag-viz-arrow-stage"; +} + /* * Show or hide the RDD DAG visualization. * @@ -79,6 +87,9 @@ var StagePageVizConstants = { * This is the narrow interface called from the Scala UI code. */ function toggleDagViz(forJob) { + var status = window.localStorage.getItem(expandDagVizArrowKey(forJob)) == "true"; + status = !status; + var arrowSelector = ".expand-dag-viz-arrow"; $(arrowSelector).toggleClass('arrow-closed'); $(arrowSelector).toggleClass('arrow-open'); @@ -93,8 +104,24 @@ function toggleDagViz(forJob) { // Save the graph for later so we don't have to render it again graphContainer().style("display", "none"); } + + window.localStorage.setItem(expandDagVizArrowKey(forJob), "" + status); } +$(function (){ + if (window.localStorage.getItem(expandDagVizArrowKey(false)) == "true") { + // Set it to false so that the click function can revert it + window.localStorage.setItem(expandDagVizArrowKey(false), "false"); + toggleDagViz(false); + } + + if (window.localStorage.getItem(expandDagVizArrowKey(true)) == "true") { + // Set it to false so that the click function can revert it + window.localStorage.setItem(expandDagVizArrowKey(true), "false"); + toggleDagViz(true); + } +}); + /* * Render the RDD DAG visualization. * diff --git a/core/src/main/resources/org/apache/spark/ui/static/timeline-view.js b/core/src/main/resources/org/apache/spark/ui/static/timeline-view.js index ca74ef9d7e94e..f4453c71df1ea 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/timeline-view.js +++ b/core/src/main/resources/org/apache/spark/ui/static/timeline-view.js @@ -66,14 +66,27 @@ function drawApplicationTimeline(groupArray, eventObjArray, startTime) { setupJobEventAction(); $("span.expand-application-timeline").click(function() { + var status = window.localStorage.getItem("expand-application-timeline") == "true"; + status = !status; + $("#application-timeline").toggleClass('collapsed'); // Switch the class of the arrow from open to closed. 
$(this).find('.expand-application-timeline-arrow').toggleClass('arrow-open'); $(this).find('.expand-application-timeline-arrow').toggleClass('arrow-closed'); + + window.localStorage.setItem("expand-application-timeline", "" + status); }); } +$(function (){ + if (window.localStorage.getItem("expand-application-timeline") == "true") { + // Set it to false so that the click function can revert it + window.localStorage.setItem("expand-application-timeline", "false"); + $("span.expand-application-timeline").trigger('click'); + } +}); + function drawJobTimeline(groupArray, eventObjArray, startTime) { var groups = new vis.DataSet(groupArray); var items = new vis.DataSet(eventObjArray); @@ -125,14 +138,27 @@ function drawJobTimeline(groupArray, eventObjArray, startTime) { setupStageEventAction(); $("span.expand-job-timeline").click(function() { + var status = window.localStorage.getItem("expand-job-timeline") == "true"; + status = !status; + $("#job-timeline").toggleClass('collapsed'); // Switch the class of the arrow from open to closed. $(this).find('.expand-job-timeline-arrow').toggleClass('arrow-open'); $(this).find('.expand-job-timeline-arrow').toggleClass('arrow-closed'); + + window.localStorage.setItem("expand-job-timeline", "" + status); }); } +$(function (){ + if (window.localStorage.getItem("expand-job-timeline") == "true") { + // Set it to false so that the click function can revert it + window.localStorage.setItem("expand-job-timeline", "false"); + $("span.expand-job-timeline").trigger('click'); + } +}); + function drawTaskAssignmentTimeline(groupArray, eventObjArray, minLaunchTime, maxFinishTime) { var groups = new vis.DataSet(groupArray); var items = new vis.DataSet(eventObjArray); @@ -176,14 +202,27 @@ function drawTaskAssignmentTimeline(groupArray, eventObjArray, minLaunchTime, ma setupZoomable("#task-assignment-timeline-zoom-lock", taskTimeline); $("span.expand-task-assignment-timeline").click(function() { + var status = window.localStorage.getItem("expand-task-assignment-timeline") == "true"; + status = !status; + $("#task-assignment-timeline").toggleClass("collapsed"); // Switch the class of the arrow from open to closed. $(this).find(".expand-task-assignment-timeline-arrow").toggleClass("arrow-open"); $(this).find(".expand-task-assignment-timeline-arrow").toggleClass("arrow-closed"); + + window.localStorage.setItem("expand-task-assignment-timeline", "" + status); }); } +$(function (){ + if (window.localStorage.getItem("expand-task-assignment-timeline") == "true") { + // Set it to false so that the click function can revert it + window.localStorage.setItem("expand-task-assignment-timeline", "false"); + $("span.expand-task-assignment-timeline").trigger('click'); + } +}); + function setupExecutorEventAction() { $(".item.box.executor").each(function () { $(this).hover( diff --git a/core/src/main/scala/org/apache/spark/ui/PagedTable.scala b/core/src/main/scala/org/apache/spark/ui/PagedTable.scala new file mode 100644 index 0000000000000..17d7b39c2d951 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/ui/PagedTable.scala @@ -0,0 +1,246 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ui + +import scala.xml.{Node, Unparsed} + +/** + * A data source that provides data for a page. + * + * @param pageSize the number of rows in a page + */ +private[ui] abstract class PagedDataSource[T](val pageSize: Int) { + + if (pageSize <= 0) { + throw new IllegalArgumentException("Page size must be positive") + } + + /** + * Return the size of all data. + */ + protected def dataSize: Int + + /** + * Slice a range of data. + */ + protected def sliceData(from: Int, to: Int): Seq[T] + + /** + * Slice the data for this page + */ + def pageData(page: Int): PageData[T] = { + val totalPages = (dataSize + pageSize - 1) / pageSize + if (page <= 0 || page > totalPages) { + throw new IndexOutOfBoundsException( + s"Page $page is out of range. Please select a page number between 1 and $totalPages.") + } + val from = (page - 1) * pageSize + val to = dataSize.min(page * pageSize) + PageData(totalPages, sliceData(from, to)) + } + +} + +/** + * The data returned by `PagedDataSource.pageData`, including the page number, the number of total + * pages and the data in this page. + */ +private[ui] case class PageData[T](totalPage: Int, data: Seq[T]) + +/** + * A paged table that will generate a HTML table for a specified page and also the page navigation. + */ +private[ui] trait PagedTable[T] { + + def tableId: String + + def tableCssClass: String + + def dataSource: PagedDataSource[T] + + def headers: Seq[Node] + + def row(t: T): Seq[Node] + + def table(page: Int): Seq[Node] = { + val _dataSource = dataSource + try { + val PageData(totalPages, data) = _dataSource.pageData(page) +
+      <div>
+        {pageNavigation(page, _dataSource.pageSize, totalPages)}
+        <table class={tableCssClass} id={tableId}>
+          {headers}
+          <tbody>
+            {data.map(row)}
+          </tbody>
+        </table>
+      </div>
+    } catch {
+      case e: IndexOutOfBoundsException =>
+        val PageData(totalPages, _) = _dataSource.pageData(1)
+        <div>
+          {pageNavigation(1, _dataSource.pageSize, totalPages)}
+          <div class="alert alert-error">{e.getMessage}</div>
+        </div>
+ } + } + + /** + * Return a page navigation. + *
+   * <ul>
+   *   <li>If the totalPages is 1, the page navigation will be empty</li>
+   *   <li>
+   *     If the totalPages is more than 1, it will create a page navigation including a group of
+   *     page numbers and a form to submit the page number.
+   *   </li>
+   * </ul>
+ * + * Here are some examples of the page navigation: + * {{{ + * << < 11 12 13* 14 15 16 17 18 19 20 > >> + * + * This is the first group, so "<<" is hidden. + * < 1 2* 3 4 5 6 7 8 9 10 > >> + * + * This is the first group and the first page, so "<<" and "<" are hidden. + * 1* 2 3 4 5 6 7 8 9 10 > >> + * + * Assume totalPages is 19. This is the last group, so ">>" is hidden. + * << < 11 12 13* 14 15 16 17 18 19 > + * + * Assume totalPages is 19. This is the last group and the last page, so ">>" and ">" are hidden. + * << < 11 12 13 14 15 16 17 18 19* + * + * * means the current page number + * << means jumping to the first page of the previous group. + * < means jumping to the previous page. + * >> means jumping to the first page of the next group. + * > means jumping to the next page. + * }}} + */ + private[ui] def pageNavigation(page: Int, pageSize: Int, totalPages: Int): Seq[Node] = { + if (totalPages == 1) { + Nil + } else { + // A group includes all page numbers will be shown in the page navigation. + // The size of group is 10 means there are 10 page numbers will be shown. + // The first group is 1 to 10, the second is 2 to 20, and so on + val groupSize = 10 + val firstGroup = 0 + val lastGroup = (totalPages - 1) / groupSize + val currentGroup = (page - 1) / groupSize + val startPage = currentGroup * groupSize + 1 + val endPage = totalPages.min(startPage + groupSize - 1) + val pageTags = (startPage to endPage).map { p => + if (p == page) { + // The current page should be disabled so that it cannot be clicked. +
+          <li class="disabled"><a href="#">{p}</a></li>
+        } else {
+          <li><a href={Unparsed(pageLink(p))}>{p}</a></li>
  • + } + } + val (goButtonJsFuncName, goButtonJsFunc) = goButtonJavascriptFunction + // When clicking the "Go" button, it will call this javascript method and then call + // "goButtonJsFuncName" + val formJs = + s"""$$(function(){ + | $$( "#form-task-page" ).submit(function(event) { + | var page = $$("#form-task-page-no").val() + | var pageSize = $$("#form-task-page-size").val() + | pageSize = pageSize ? pageSize: 100; + | if (page != "") { + | ${goButtonJsFuncName}(page, pageSize); + | } + | event.preventDefault(); + | }); + |}); + """.stripMargin + +
    +
    +
    + + + + + + +
    +
    + + +
    + } + } + + /** + * Return a link to jump to a page. + */ + def pageLink(page: Int): String + + /** + * Only the implementation knows how to create the url with a page number and the page size, so we + * leave this one to the implementation. The implementation should create a JavaScript method that + * accepts a page number along with the page size and jumps to the page. The return value is this + * method name and its JavaScript codes. + */ + def goButtonJavascriptFunction: (String, String) +} diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 6e077bf3e70d5..cf04b5e59239b 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -17,6 +17,7 @@ package org.apache.spark.ui.jobs +import java.net.URLEncoder import java.util.Date import javax.servlet.http.HttpServletRequest @@ -27,13 +28,14 @@ import org.apache.commons.lang3.StringEscapeUtils import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler.{AccumulableInfo, TaskInfo} -import org.apache.spark.ui.{ToolTips, WebUIPage, UIUtils} +import org.apache.spark.ui._ import org.apache.spark.ui.jobs.UIData._ -import org.apache.spark.ui.scope.RDDOperationGraph import org.apache.spark.util.{Utils, Distribution} /** Page showing statistics and task list for a given stage */ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { + import StagePage._ + private val progressListener = parent.progressListener private val operationGraphListener = parent.operationGraphListener @@ -74,6 +76,16 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { val parameterAttempt = request.getParameter("attempt") require(parameterAttempt != null && parameterAttempt.nonEmpty, "Missing attempt parameter") + val parameterTaskPage = request.getParameter("task.page") + val parameterTaskSortColumn = request.getParameter("task.sort") + val parameterTaskSortDesc = request.getParameter("task.desc") + val parameterTaskPageSize = request.getParameter("task.pageSize") + + val taskPage = Option(parameterTaskPage).map(_.toInt).getOrElse(1) + val taskSortColumn = Option(parameterTaskSortColumn).getOrElse("Index") + val taskSortDesc = Option(parameterTaskSortDesc).map(_.toBoolean).getOrElse(false) + val taskPageSize = Option(parameterTaskPageSize).map(_.toInt).getOrElse(100) + // If this is set, expand the dag visualization by default val expandDagVizParam = request.getParameter("expandDagViz") val expandDagViz = expandDagVizParam != null && expandDagVizParam.toBoolean @@ -231,52 +243,47 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { accumulableRow, accumulables.values.toSeq) - val taskHeadersAndCssClasses: Seq[(String, String)] = - Seq( - ("Index", ""), ("ID", ""), ("Attempt", ""), ("Status", ""), ("Locality Level", ""), - ("Executor ID / Host", ""), ("Launch Time", ""), ("Duration", ""), - ("Scheduler Delay", TaskDetailsClassNames.SCHEDULER_DELAY), - ("Task Deserialization Time", TaskDetailsClassNames.TASK_DESERIALIZATION_TIME), - ("GC Time", ""), - ("Result Serialization Time", TaskDetailsClassNames.RESULT_SERIALIZATION_TIME), - ("Getting Result Time", TaskDetailsClassNames.GETTING_RESULT_TIME)) ++ - {if (hasAccumulators) Seq(("Accumulators", "")) else Nil} ++ - {if (stageData.hasInput) Seq(("Input Size / Records", "")) else Nil} ++ - {if (stageData.hasOutput) Seq(("Output Size / Records", "")) else Nil} ++ 
- {if (stageData.hasShuffleRead) { - Seq(("Shuffle Read Blocked Time", TaskDetailsClassNames.SHUFFLE_READ_BLOCKED_TIME), - ("Shuffle Read Size / Records", ""), - ("Shuffle Remote Reads", TaskDetailsClassNames.SHUFFLE_READ_REMOTE_SIZE)) - } else { - Nil - }} ++ - {if (stageData.hasShuffleWrite) { - Seq(("Write Time", ""), ("Shuffle Write Size / Records", "")) - } else { - Nil - }} ++ - {if (stageData.hasBytesSpilled) { - Seq(("Shuffle Spill (Memory)", ""), ("Shuffle Spill (Disk)", "")) - } else { - Nil - }} ++ - Seq(("Errors", "")) - - val unzipped = taskHeadersAndCssClasses.unzip - val currentTime = System.currentTimeMillis() - val taskTable = UIUtils.listingTable( - unzipped._1, - taskRow( + val (taskTable, taskTableHTML) = try { + val _taskTable = new TaskPagedTable( + UIUtils.prependBaseUri(parent.basePath) + + s"/stages/stage?id=${stageId}&attempt=${stageAttemptId}", + tasks, hasAccumulators, stageData.hasInput, stageData.hasOutput, stageData.hasShuffleRead, stageData.hasShuffleWrite, stageData.hasBytesSpilled, - currentTime), - tasks, - headerClasses = unzipped._2) + currentTime, + pageSize = taskPageSize, + sortColumn = taskSortColumn, + desc = taskSortDesc + ) + (_taskTable, _taskTable.table(taskPage)) + } catch { + case e @ (_ : IllegalArgumentException | _ : IndexOutOfBoundsException) => + (null,
          <div class="alert alert-error">{e.getMessage}</div>
    ) + } + + val jsForScrollingDownToTaskTable = + + + val taskIdsInPage = if (taskTable == null) Set.empty[Long] + else taskTable.dataSource.slicedTaskIds + // Excludes tasks which failed and have incomplete metrics val validTasks = tasks.filter(t => t.taskInfo.status == "SUCCESS" && t.taskMetrics.isDefined) @@ -499,12 +506,15 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { dagViz ++ maybeExpandDagViz ++ showAdditionalMetrics ++ - makeTimeline(stageData.taskData.values.toSeq, currentTime) ++ + makeTimeline( + // Only show the tasks in the table + stageData.taskData.values.toSeq.filter(t => taskIdsInPage.contains(t.taskInfo.taskId)), + currentTime) ++

      <h4>Summary Metrics for {numCompleted} Completed Tasks</h4> ++
      <div>{summaryTable.getOrElse("No tasks have reported metrics yet.")}</div> ++
      <h4>Aggregated Metrics by Executor</h4> ++ executorTable.toNodeSeq ++
      maybeAccumulableTable ++
-      <h4>Tasks</h4> ++ taskTable
+      <h4>Tasks</h4>
    ++ taskTableHTML ++ jsForScrollingDownToTaskTable UIUtils.headerSparkPage(stageHeader, content, parent, showVisualization = true) } } @@ -679,164 +689,619 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { } - def taskRow( - hasAccumulators: Boolean, - hasInput: Boolean, - hasOutput: Boolean, - hasShuffleRead: Boolean, - hasShuffleWrite: Boolean, - hasBytesSpilled: Boolean, - currentTime: Long)(taskData: TaskUIData): Seq[Node] = { - taskData match { case TaskUIData(info, metrics, errorMessage) => - val duration = if (info.status == "RUNNING") info.timeRunning(currentTime) - else metrics.map(_.executorRunTime).getOrElse(1L) - val formatDuration = if (info.status == "RUNNING") UIUtils.formatDuration(duration) - else metrics.map(m => UIUtils.formatDuration(m.executorRunTime)).getOrElse("") - val schedulerDelay = metrics.map(getSchedulerDelay(info, _, currentTime)).getOrElse(0L) - val gcTime = metrics.map(_.jvmGCTime).getOrElse(0L) - val taskDeserializationTime = metrics.map(_.executorDeserializeTime).getOrElse(0L) - val serializationTime = metrics.map(_.resultSerializationTime).getOrElse(0L) - val gettingResultTime = getGettingResultTime(info, currentTime) - - val maybeAccumulators = info.accumulables - val accumulatorsReadable = maybeAccumulators.map { acc => - StringEscapeUtils.escapeHtml4(s"${acc.name}: ${acc.update.get}") +} + +private[ui] object StagePage { + private[ui] def getGettingResultTime(info: TaskInfo, currentTime: Long): Long = { + if (info.gettingResult) { + if (info.finished) { + info.finishTime - info.gettingResultTime + } else { + // The task is still fetching the result. + currentTime - info.gettingResultTime } + } else { + 0L + } + } - val maybeInput = metrics.flatMap(_.inputMetrics) - val inputSortable = maybeInput.map(_.bytesRead.toString).getOrElse("") - val inputReadable = maybeInput - .map(m => s"${Utils.bytesToString(m.bytesRead)} (${m.readMethod.toString.toLowerCase()})") - .getOrElse("") - val inputRecords = maybeInput.map(_.recordsRead.toString).getOrElse("") - - val maybeOutput = metrics.flatMap(_.outputMetrics) - val outputSortable = maybeOutput.map(_.bytesWritten.toString).getOrElse("") - val outputReadable = maybeOutput - .map(m => s"${Utils.bytesToString(m.bytesWritten)}") - .getOrElse("") - val outputRecords = maybeOutput.map(_.recordsWritten.toString).getOrElse("") - - val maybeShuffleRead = metrics.flatMap(_.shuffleReadMetrics) - val shuffleReadBlockedTimeSortable = maybeShuffleRead - .map(_.fetchWaitTime.toString).getOrElse("") - val shuffleReadBlockedTimeReadable = - maybeShuffleRead.map(ms => UIUtils.formatDuration(ms.fetchWaitTime)).getOrElse("") - - val totalShuffleBytes = maybeShuffleRead.map(_.totalBytesRead) - val shuffleReadSortable = totalShuffleBytes.map(_.toString).getOrElse("") - val shuffleReadReadable = totalShuffleBytes.map(Utils.bytesToString).getOrElse("") - val shuffleReadRecords = maybeShuffleRead.map(_.recordsRead.toString).getOrElse("") - - val remoteShuffleBytes = maybeShuffleRead.map(_.remoteBytesRead) - val shuffleReadRemoteSortable = remoteShuffleBytes.map(_.toString).getOrElse("") - val shuffleReadRemoteReadable = remoteShuffleBytes.map(Utils.bytesToString).getOrElse("") - - val maybeShuffleWrite = metrics.flatMap(_.shuffleWriteMetrics) - val shuffleWriteSortable = maybeShuffleWrite.map(_.shuffleBytesWritten.toString).getOrElse("") - val shuffleWriteReadable = maybeShuffleWrite - .map(m => s"${Utils.bytesToString(m.shuffleBytesWritten)}").getOrElse("") - val shuffleWriteRecords = maybeShuffleWrite 
- .map(_.shuffleRecordsWritten.toString).getOrElse("") - - val maybeWriteTime = metrics.flatMap(_.shuffleWriteMetrics).map(_.shuffleWriteTime) - val writeTimeSortable = maybeWriteTime.map(_.toString).getOrElse("") - val writeTimeReadable = maybeWriteTime.map(t => t / (1000 * 1000)).map { ms => - if (ms == 0) "" else UIUtils.formatDuration(ms) - }.getOrElse("") - - val maybeMemoryBytesSpilled = metrics.map(_.memoryBytesSpilled) - val memoryBytesSpilledSortable = maybeMemoryBytesSpilled.map(_.toString).getOrElse("") - val memoryBytesSpilledReadable = - maybeMemoryBytesSpilled.map(Utils.bytesToString).getOrElse("") - - val maybeDiskBytesSpilled = metrics.map(_.diskBytesSpilled) - val diskBytesSpilledSortable = maybeDiskBytesSpilled.map(_.toString).getOrElse("") - val diskBytesSpilledReadable = maybeDiskBytesSpilled.map(Utils.bytesToString).getOrElse("") - - - {info.index} - {info.taskId} - { - if (info.speculative) s"${info.attempt} (speculative)" else info.attempt.toString - } - {info.status} - {info.taskLocality} - {info.executorId} / {info.host} - {UIUtils.formatDate(new Date(info.launchTime))} - - {formatDuration} - - - {UIUtils.formatDuration(schedulerDelay.toLong)} - - - {UIUtils.formatDuration(taskDeserializationTime.toLong)} - - - {if (gcTime > 0) UIUtils.formatDuration(gcTime) else ""} - - - {UIUtils.formatDuration(serializationTime)} - - - {UIUtils.formatDuration(gettingResultTime)} - - {if (hasAccumulators) { - - {Unparsed(accumulatorsReadable.mkString("
    "))} - - }} - {if (hasInput) { - - {s"$inputReadable / $inputRecords"} - - }} - {if (hasOutput) { - - {s"$outputReadable / $outputRecords"} - - }} + private[ui] def getSchedulerDelay( + info: TaskInfo, metrics: TaskMetrics, currentTime: Long): Long = { + if (info.finished) { + val totalExecutionTime = info.finishTime - info.launchTime + val executorOverhead = (metrics.executorDeserializeTime + + metrics.resultSerializationTime) + math.max( + 0, + totalExecutionTime - metrics.executorRunTime - executorOverhead - + getGettingResultTime(info, currentTime)) + } else { + // The task is still running and the metrics like executorRunTime are not available. + 0L + } + } +} + +private[ui] case class TaskTableRowInputData(inputSortable: Long, inputReadable: String) + +private[ui] case class TaskTableRowOutputData(outputSortable: Long, outputReadable: String) + +private[ui] case class TaskTableRowShuffleReadData( + shuffleReadBlockedTimeSortable: Long, + shuffleReadBlockedTimeReadable: String, + shuffleReadSortable: Long, + shuffleReadReadable: String, + shuffleReadRemoteSortable: Long, + shuffleReadRemoteReadable: String) + +private[ui] case class TaskTableRowShuffleWriteData( + writeTimeSortable: Long, + writeTimeReadable: String, + shuffleWriteSortable: Long, + shuffleWriteReadable: String) + +private[ui] case class TaskTableRowBytesSpilledData( + memoryBytesSpilledSortable: Long, + memoryBytesSpilledReadable: String, + diskBytesSpilledSortable: Long, + diskBytesSpilledReadable: String) + +/** + * Contains all data that needs for sorting and generating HTML. Using this one rather than + * TaskUIData to avoid creating duplicate contents during sorting the data. + */ +private[ui] case class TaskTableRowData( + index: Int, + taskId: Long, + attempt: Int, + speculative: Boolean, + status: String, + taskLocality: String, + executorIdAndHost: String, + launchTime: Long, + duration: Long, + formatDuration: String, + schedulerDelay: Long, + taskDeserializationTime: Long, + gcTime: Long, + serializationTime: Long, + gettingResultTime: Long, + accumulators: Option[String], // HTML + input: Option[TaskTableRowInputData], + output: Option[TaskTableRowOutputData], + shuffleRead: Option[TaskTableRowShuffleReadData], + shuffleWrite: Option[TaskTableRowShuffleWriteData], + bytesSpilled: Option[TaskTableRowBytesSpilledData], + error: String) + +private[ui] class TaskDataSource( + tasks: Seq[TaskUIData], + hasAccumulators: Boolean, + hasInput: Boolean, + hasOutput: Boolean, + hasShuffleRead: Boolean, + hasShuffleWrite: Boolean, + hasBytesSpilled: Boolean, + currentTime: Long, + pageSize: Int, + sortColumn: String, + desc: Boolean) extends PagedDataSource[TaskTableRowData](pageSize) { + import StagePage._ + + // Convert TaskUIData to TaskTableRowData which contains the final contents to show in the table + // so that we can avoid creating duplicate contents during sorting the data + private val data = tasks.map(taskRow).sorted(ordering(sortColumn, desc)) + + private var _slicedTaskIds: Set[Long] = null + + override def dataSize: Int = data.size + + override def sliceData(from: Int, to: Int): Seq[TaskTableRowData] = { + val r = data.slice(from, to) + _slicedTaskIds = r.map(_.taskId).toSet + r + } + + def slicedTaskIds: Set[Long] = _slicedTaskIds + + private def taskRow(taskData: TaskUIData): TaskTableRowData = { + val TaskUIData(info, metrics, errorMessage) = taskData + val duration = if (info.status == "RUNNING") info.timeRunning(currentTime) + else metrics.map(_.executorRunTime).getOrElse(1L) + val 
formatDuration = if (info.status == "RUNNING") UIUtils.formatDuration(duration) + else metrics.map(m => UIUtils.formatDuration(m.executorRunTime)).getOrElse("") + val schedulerDelay = metrics.map(getSchedulerDelay(info, _, currentTime)).getOrElse(0L) + val gcTime = metrics.map(_.jvmGCTime).getOrElse(0L) + val taskDeserializationTime = metrics.map(_.executorDeserializeTime).getOrElse(0L) + val serializationTime = metrics.map(_.resultSerializationTime).getOrElse(0L) + val gettingResultTime = getGettingResultTime(info, currentTime) + + val maybeAccumulators = info.accumulables + val accumulatorsReadable = maybeAccumulators.map { acc => + StringEscapeUtils.escapeHtml4(s"${acc.name}: ${acc.update.get}") + } + + val maybeInput = metrics.flatMap(_.inputMetrics) + val inputSortable = maybeInput.map(_.bytesRead).getOrElse(0L) + val inputReadable = maybeInput + .map(m => s"${Utils.bytesToString(m.bytesRead)} (${m.readMethod.toString.toLowerCase()})") + .getOrElse("") + val inputRecords = maybeInput.map(_.recordsRead.toString).getOrElse("") + + val maybeOutput = metrics.flatMap(_.outputMetrics) + val outputSortable = maybeOutput.map(_.bytesWritten).getOrElse(0L) + val outputReadable = maybeOutput + .map(m => s"${Utils.bytesToString(m.bytesWritten)}") + .getOrElse("") + val outputRecords = maybeOutput.map(_.recordsWritten.toString).getOrElse("") + + val maybeShuffleRead = metrics.flatMap(_.shuffleReadMetrics) + val shuffleReadBlockedTimeSortable = maybeShuffleRead.map(_.fetchWaitTime).getOrElse(0L) + val shuffleReadBlockedTimeReadable = + maybeShuffleRead.map(ms => UIUtils.formatDuration(ms.fetchWaitTime)).getOrElse("") + + val totalShuffleBytes = maybeShuffleRead.map(_.totalBytesRead) + val shuffleReadSortable = totalShuffleBytes.getOrElse(0L) + val shuffleReadReadable = totalShuffleBytes.map(Utils.bytesToString).getOrElse("") + val shuffleReadRecords = maybeShuffleRead.map(_.recordsRead.toString).getOrElse("") + + val remoteShuffleBytes = maybeShuffleRead.map(_.remoteBytesRead) + val shuffleReadRemoteSortable = remoteShuffleBytes.getOrElse(0L) + val shuffleReadRemoteReadable = remoteShuffleBytes.map(Utils.bytesToString).getOrElse("") + + val maybeShuffleWrite = metrics.flatMap(_.shuffleWriteMetrics) + val shuffleWriteSortable = maybeShuffleWrite.map(_.shuffleBytesWritten).getOrElse(0L) + val shuffleWriteReadable = maybeShuffleWrite + .map(m => s"${Utils.bytesToString(m.shuffleBytesWritten)}").getOrElse("") + val shuffleWriteRecords = maybeShuffleWrite + .map(_.shuffleRecordsWritten.toString).getOrElse("") + + val maybeWriteTime = metrics.flatMap(_.shuffleWriteMetrics).map(_.shuffleWriteTime) + val writeTimeSortable = maybeWriteTime.getOrElse(0L) + val writeTimeReadable = maybeWriteTime.map(t => t / (1000 * 1000)).map { ms => + if (ms == 0) "" else UIUtils.formatDuration(ms) + }.getOrElse("") + + val maybeMemoryBytesSpilled = metrics.map(_.memoryBytesSpilled) + val memoryBytesSpilledSortable = maybeMemoryBytesSpilled.getOrElse(0L) + val memoryBytesSpilledReadable = + maybeMemoryBytesSpilled.map(Utils.bytesToString).getOrElse("") + + val maybeDiskBytesSpilled = metrics.map(_.diskBytesSpilled) + val diskBytesSpilledSortable = maybeDiskBytesSpilled.getOrElse(0L) + val diskBytesSpilledReadable = maybeDiskBytesSpilled.map(Utils.bytesToString).getOrElse("") + + val input = + if (hasInput) { + Some(TaskTableRowInputData(inputSortable, s"$inputReadable / $inputRecords")) + } else { + None + } + + val output = + if (hasOutput) { + Some(TaskTableRowOutputData(outputSortable, s"$outputReadable / 
$outputRecords")) + } else { + None + } + + val shuffleRead = + if (hasShuffleRead) { + Some(TaskTableRowShuffleReadData( + shuffleReadBlockedTimeSortable, + shuffleReadBlockedTimeReadable, + shuffleReadSortable, + s"$shuffleReadReadable / $shuffleReadRecords", + shuffleReadRemoteSortable, + shuffleReadRemoteReadable + )) + } else { + None + } + + val shuffleWrite = + if (hasShuffleWrite) { + Some(TaskTableRowShuffleWriteData( + writeTimeSortable, + writeTimeReadable, + shuffleWriteSortable, + s"$shuffleWriteReadable / $shuffleWriteRecords" + )) + } else { + None + } + + val bytesSpilled = + if (hasBytesSpilled) { + Some(TaskTableRowBytesSpilledData( + memoryBytesSpilledSortable, + memoryBytesSpilledReadable, + diskBytesSpilledSortable, + diskBytesSpilledReadable + )) + } else { + None + } + + TaskTableRowData( + info.index, + info.taskId, + info.attempt, + info.speculative, + info.status, + info.taskLocality.toString, + s"${info.executorId} / ${info.host}", + info.launchTime, + duration, + formatDuration, + schedulerDelay, + taskDeserializationTime, + gcTime, + serializationTime, + gettingResultTime, + if (hasAccumulators) Some(accumulatorsReadable.mkString("
    ")) else None, + input, + output, + shuffleRead, + shuffleWrite, + bytesSpilled, + errorMessage.getOrElse("") + ) + } + + /** + * Return Ordering according to sortColumn and desc + */ + private def ordering(sortColumn: String, desc: Boolean): Ordering[TaskTableRowData] = { + val ordering = sortColumn match { + case "Index" => new Ordering[TaskTableRowData] { + override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = + Ordering.Int.compare(x.index, y.index) + } + case "ID" => new Ordering[TaskTableRowData] { + override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = + Ordering.Long.compare(x.taskId, y.taskId) + } + case "Attempt" => new Ordering[TaskTableRowData] { + override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = + Ordering.Int.compare(x.attempt, y.attempt) + } + case "Status" => new Ordering[TaskTableRowData] { + override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = + Ordering.String.compare(x.status, y.status) + } + case "Locality Level" => new Ordering[TaskTableRowData] { + override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = + Ordering.String.compare(x.taskLocality, y.taskLocality) + } + case "Executor ID / Host" => new Ordering[TaskTableRowData] { + override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = + Ordering.String.compare(x.executorIdAndHost, y.executorIdAndHost) + } + case "Launch Time" => new Ordering[TaskTableRowData] { + override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = + Ordering.Long.compare(x.launchTime, y.launchTime) + } + case "Duration" => new Ordering[TaskTableRowData] { + override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = + Ordering.Long.compare(x.duration, y.duration) + } + case "Scheduler Delay" => new Ordering[TaskTableRowData] { + override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = + Ordering.Long.compare(x.schedulerDelay, y.schedulerDelay) + } + case "Task Deserialization Time" => new Ordering[TaskTableRowData] { + override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = + Ordering.Long.compare(x.taskDeserializationTime, y.taskDeserializationTime) + } + case "GC Time" => new Ordering[TaskTableRowData] { + override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = + Ordering.Long.compare(x.gcTime, y.gcTime) + } + case "Result Serialization Time" => new Ordering[TaskTableRowData] { + override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = + Ordering.Long.compare(x.serializationTime, y.serializationTime) + } + case "Getting Result Time" => new Ordering[TaskTableRowData] { + override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = + Ordering.Long.compare(x.gettingResultTime, y.gettingResultTime) + } + case "Accumulators" => + if (hasAccumulators) { + new Ordering[TaskTableRowData] { + override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = + Ordering.String.compare(x.accumulators.get, y.accumulators.get) + } + } else { + throw new IllegalArgumentException( + "Cannot sort by Accumulators because of no accumulators") + } + case "Input Size / Records" => + if (hasInput) { + new Ordering[TaskTableRowData] { + override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = + Ordering.Long.compare(x.input.get.inputSortable, y.input.get.inputSortable) + } + } else { + throw new IllegalArgumentException( + "Cannot sort by Input Size / Records because of no inputs") + } + case "Output Size / Records" => + if (hasOutput) { + new 
Ordering[TaskTableRowData] { + override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = + Ordering.Long.compare(x.output.get.outputSortable, y.output.get.outputSortable) + } + } else { + throw new IllegalArgumentException( + "Cannot sort by Output Size / Records because of no outputs") + } + // ShuffleRead + case "Shuffle Read Blocked Time" => + if (hasShuffleRead) { + new Ordering[TaskTableRowData] { + override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = + Ordering.Long.compare(x.shuffleRead.get.shuffleReadBlockedTimeSortable, + y.shuffleRead.get.shuffleReadBlockedTimeSortable) + } + } else { + throw new IllegalArgumentException( + "Cannot sort by Shuffle Read Blocked Time because of no shuffle reads") + } + case "Shuffle Read Size / Records" => + if (hasShuffleRead) { + new Ordering[TaskTableRowData] { + override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = + Ordering.Long.compare(x.shuffleRead.get.shuffleReadSortable, + y.shuffleRead.get.shuffleReadSortable) + } + } else { + throw new IllegalArgumentException( + "Cannot sort by Shuffle Read Size / Records because of no shuffle reads") + } + case "Shuffle Remote Reads" => + if (hasShuffleRead) { + new Ordering[TaskTableRowData] { + override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = + Ordering.Long.compare(x.shuffleRead.get.shuffleReadRemoteSortable, + y.shuffleRead.get.shuffleReadRemoteSortable) + } + } else { + throw new IllegalArgumentException( + "Cannot sort by Shuffle Remote Reads because of no shuffle reads") + } + // ShuffleWrite + case "Write Time" => + if (hasShuffleWrite) { + new Ordering[TaskTableRowData] { + override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = + Ordering.Long.compare(x.shuffleWrite.get.writeTimeSortable, + y.shuffleWrite.get.writeTimeSortable) + } + } else { + throw new IllegalArgumentException( + "Cannot sort by Write Time because of no shuffle writes") + } + case "Shuffle Write Size / Records" => + if (hasShuffleWrite) { + new Ordering[TaskTableRowData] { + override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = + Ordering.Long.compare(x.shuffleWrite.get.shuffleWriteSortable, + y.shuffleWrite.get.shuffleWriteSortable) + } + } else { + throw new IllegalArgumentException( + "Cannot sort by Shuffle Write Size / Records because of no shuffle writes") + } + // BytesSpilled + case "Shuffle Spill (Memory)" => + if (hasBytesSpilled) { + new Ordering[TaskTableRowData] { + override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = + Ordering.Long.compare(x.bytesSpilled.get.memoryBytesSpilledSortable, + y.bytesSpilled.get.memoryBytesSpilledSortable) + } + } else { + throw new IllegalArgumentException( + "Cannot sort by Shuffle Spill (Memory) because of no spills") + } + case "Shuffle Spill (Disk)" => + if (hasBytesSpilled) { + new Ordering[TaskTableRowData] { + override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = + Ordering.Long.compare(x.bytesSpilled.get.diskBytesSpilledSortable, + y.bytesSpilled.get.diskBytesSpilledSortable) + } + } else { + throw new IllegalArgumentException( + "Cannot sort by Shuffle Spill (Disk) because of no spills") + } + case "Errors" => new Ordering[TaskTableRowData] { + override def compare(x: TaskTableRowData, y: TaskTableRowData): Int = + Ordering.String.compare(x.error, y.error) + } + case unknownColumn => throw new IllegalArgumentException(s"Unknown column: $unknownColumn") + } + if (desc) { + ordering.reverse + } else { + ordering + } + } + +} + +private[ui] class 
TaskPagedTable( + basePath: String, + data: Seq[TaskUIData], + hasAccumulators: Boolean, + hasInput: Boolean, + hasOutput: Boolean, + hasShuffleRead: Boolean, + hasShuffleWrite: Boolean, + hasBytesSpilled: Boolean, + currentTime: Long, + pageSize: Int, + sortColumn: String, + desc: Boolean) extends PagedTable[TaskTableRowData]{ + + override def tableId: String = "" + + override def tableCssClass: String = "table table-bordered table-condensed table-striped" + + override val dataSource: TaskDataSource = new TaskDataSource( + data, + hasAccumulators, + hasInput, + hasOutput, + hasShuffleRead, + hasShuffleWrite, + hasBytesSpilled, + currentTime, + pageSize, + sortColumn, + desc + ) + + override def pageLink(page: Int): String = { + val encodedSortColumn = URLEncoder.encode(sortColumn, "UTF-8") + s"${basePath}&task.page=$page&task.sort=${encodedSortColumn}&task.desc=${desc}" + + s"&task.pageSize=${pageSize}" + } + + override def goButtonJavascriptFunction: (String, String) = { + val jsFuncName = "goToTaskPage" + val encodedSortColumn = URLEncoder.encode(sortColumn, "UTF-8") + val jsFunc = s""" + |currentTaskPageSize = ${pageSize} + |function goToTaskPage(page, pageSize) { + | // Set page to 1 if the page size changes + | page = pageSize == currentTaskPageSize ? page : 1; + | var url = "${basePath}&task.sort=${encodedSortColumn}&task.desc=${desc}" + + | "&task.page=" + page + "&task.pageSize=" + pageSize; + | window.location.href = url; + |} + """.stripMargin + (jsFuncName, jsFunc) + } + + def headers: Seq[Node] = { + val taskHeadersAndCssClasses: Seq[(String, String)] = + Seq( + ("Index", ""), ("ID", ""), ("Attempt", ""), ("Status", ""), ("Locality Level", ""), + ("Executor ID / Host", ""), ("Launch Time", ""), ("Duration", ""), + ("Scheduler Delay", TaskDetailsClassNames.SCHEDULER_DELAY), + ("Task Deserialization Time", TaskDetailsClassNames.TASK_DESERIALIZATION_TIME), + ("GC Time", ""), + ("Result Serialization Time", TaskDetailsClassNames.RESULT_SERIALIZATION_TIME), + ("Getting Result Time", TaskDetailsClassNames.GETTING_RESULT_TIME)) ++ + {if (hasAccumulators) Seq(("Accumulators", "")) else Nil} ++ + {if (hasInput) Seq(("Input Size / Records", "")) else Nil} ++ + {if (hasOutput) Seq(("Output Size / Records", "")) else Nil} ++ {if (hasShuffleRead) { - - {shuffleReadBlockedTimeReadable} - - - {s"$shuffleReadReadable / $shuffleReadRecords"} - - - {shuffleReadRemoteReadable} - - }} + Seq(("Shuffle Read Blocked Time", TaskDetailsClassNames.SHUFFLE_READ_BLOCKED_TIME), + ("Shuffle Read Size / Records", ""), + ("Shuffle Remote Reads", TaskDetailsClassNames.SHUFFLE_READ_REMOTE_SIZE)) + } else { + Nil + }} ++ {if (hasShuffleWrite) { - - {writeTimeReadable} - - - {s"$shuffleWriteReadable / $shuffleWriteRecords"} - - }} + Seq(("Write Time", ""), ("Shuffle Write Size / Records", "")) + } else { + Nil + }} ++ {if (hasBytesSpilled) { - - {memoryBytesSpilledReadable} - - - {diskBytesSpilledReadable} - - }} - {errorMessageCell(errorMessage)} - + Seq(("Shuffle Spill (Memory)", ""), ("Shuffle Spill (Disk)", "")) + } else { + Nil + }} ++ + Seq(("Errors", "")) + + if (!taskHeadersAndCssClasses.map(_._1).contains(sortColumn)) { + new IllegalArgumentException(s"Unknown column: $sortColumn") } + + val headerRow: Seq[Node] = { + taskHeadersAndCssClasses.map { case (header, cssClass) => + if (header == sortColumn) { + val headerLink = + s"$basePath&task.sort=${URLEncoder.encode(header, "UTF-8")}&task.desc=${!desc}" + + s"&task.pageSize=${pageSize}" + val js = Unparsed(s"window.location.href='${headerLink}'") + val 
arrow = if (desc) "▾" else "▴" // UP or DOWN + + {header} +  {Unparsed(arrow)} + + } else { + val headerLink = + s"$basePath&task.sort=${URLEncoder.encode(header, "UTF-8")}&task.pageSize=${pageSize}" + val js = Unparsed(s"window.location.href='${headerLink}'") + + {header} + + } + } + } + {headerRow} + } + + def row(task: TaskTableRowData): Seq[Node] = { + + {task.index} + {task.taskId} + {if (task.speculative) s"${task.attempt} (speculative)" else task.attempt.toString} + {task.status} + {task.taskLocality} + {task.executorIdAndHost} + {UIUtils.formatDate(new Date(task.launchTime))} + {task.formatDuration} + + {UIUtils.formatDuration(task.schedulerDelay)} + + + {UIUtils.formatDuration(task.taskDeserializationTime)} + + + {if (task.gcTime > 0) UIUtils.formatDuration(task.gcTime) else ""} + + + {UIUtils.formatDuration(task.serializationTime)} + + + {UIUtils.formatDuration(task.gettingResultTime)} + + {if (task.accumulators.nonEmpty) { + {Unparsed(task.accumulators.get)} + }} + {if (task.input.nonEmpty) { + {task.input.get.inputReadable} + }} + {if (task.output.nonEmpty) { + {task.output.get.outputReadable} + }} + {if (task.shuffleRead.nonEmpty) { + + {task.shuffleRead.get.shuffleReadBlockedTimeReadable} + + {task.shuffleRead.get.shuffleReadReadable} + + {task.shuffleRead.get.shuffleReadRemoteReadable} + + }} + {if (task.shuffleWrite.nonEmpty) { + {task.shuffleWrite.get.writeTimeReadable} + {task.shuffleWrite.get.shuffleWriteReadable} + }} + {if (task.bytesSpilled.nonEmpty) { + {task.bytesSpilled.get.memoryBytesSpilledReadable} + {task.bytesSpilled.get.diskBytesSpilledReadable} + }} + {errorMessageCell(task.error)} + } - private def errorMessageCell(errorMessage: Option[String]): Seq[Node] = { - val error = errorMessage.getOrElse("") + private def errorMessageCell(error: String): Seq[Node] = { val isMultiline = error.indexOf('\n') >= 0 // Display the first line by default val errorSummary = StringEscapeUtils.escapeHtml4( @@ -860,32 +1325,4 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { } {errorSummary}{details} } - - private def getGettingResultTime(info: TaskInfo, currentTime: Long): Long = { - if (info.gettingResult) { - if (info.finished) { - info.finishTime - info.gettingResultTime - } else { - // The task is still fetching the result. - currentTime - info.gettingResultTime - } - } else { - 0L - } - } - - private def getSchedulerDelay(info: TaskInfo, metrics: TaskMetrics, currentTime: Long): Long = { - if (info.finished) { - val totalExecutionTime = info.finishTime - info.launchTime - val executorOverhead = (metrics.executorDeserializeTime + - metrics.resultSerializationTime) - math.max( - 0, - totalExecutionTime - metrics.executorRunTime - executorOverhead - - getGettingResultTime(info, currentTime)) - } else { - // The task is still running and the metrics like executorRunTime are not available. - 0L - } - } } diff --git a/core/src/test/scala/org/apache/spark/ui/PagedTableSuite.scala b/core/src/test/scala/org/apache/spark/ui/PagedTableSuite.scala new file mode 100644 index 0000000000000..cc76c141c53cc --- /dev/null +++ b/core/src/test/scala/org/apache/spark/ui/PagedTableSuite.scala @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ui + +import scala.xml.Node + +import org.apache.spark.SparkFunSuite + +class PagedDataSourceSuite extends SparkFunSuite { + + test("basic") { + val dataSource1 = new SeqPagedDataSource[Int](1 to 5, pageSize = 2) + assert(dataSource1.pageData(1) === PageData(3, (1 to 2))) + + val dataSource2 = new SeqPagedDataSource[Int](1 to 5, pageSize = 2) + assert(dataSource2.pageData(2) === PageData(3, (3 to 4))) + + val dataSource3 = new SeqPagedDataSource[Int](1 to 5, pageSize = 2) + assert(dataSource3.pageData(3) === PageData(3, Seq(5))) + + val dataSource4 = new SeqPagedDataSource[Int](1 to 5, pageSize = 2) + val e1 = intercept[IndexOutOfBoundsException] { + dataSource4.pageData(4) + } + assert(e1.getMessage === "Page 4 is out of range. Please select a page number between 1 and 3.") + + val dataSource5 = new SeqPagedDataSource[Int](1 to 5, pageSize = 2) + val e2 = intercept[IndexOutOfBoundsException] { + dataSource5.pageData(0) + } + assert(e2.getMessage === "Page 0 is out of range. Please select a page number between 1 and 3.") + + } +} + +class PagedTableSuite extends SparkFunSuite { + test("pageNavigation") { + // Create a fake PagedTable to test pageNavigation + val pagedTable = new PagedTable[Int] { + override def tableId: String = "" + + override def tableCssClass: String = "" + + override def dataSource: PagedDataSource[Int] = null + + override def pageLink(page: Int): String = page.toString + + override def headers: Seq[Node] = Nil + + override def row(t: Int): Seq[Node] = Nil + + override def goButtonJavascriptFunction: (String, String) = ("", "") + } + + assert(pagedTable.pageNavigation(1, 10, 1) === Nil) + assert( + (pagedTable.pageNavigation(1, 10, 2).head \\ "li").map(_.text.trim) === Seq("1", "2", ">")) + assert( + (pagedTable.pageNavigation(2, 10, 2).head \\ "li").map(_.text.trim) === Seq("<", "1", "2")) + + assert((pagedTable.pageNavigation(1, 10, 100).head \\ "li").map(_.text.trim) === + (1 to 10).map(_.toString) ++ Seq(">", ">>")) + assert((pagedTable.pageNavigation(2, 10, 100).head \\ "li").map(_.text.trim) === + Seq("<") ++ (1 to 10).map(_.toString) ++ Seq(">", ">>")) + + assert((pagedTable.pageNavigation(100, 10, 100).head \\ "li").map(_.text.trim) === + Seq("<<", "<") ++ (91 to 100).map(_.toString)) + assert((pagedTable.pageNavigation(99, 10, 100).head \\ "li").map(_.text.trim) === + Seq("<<", "<") ++ (91 to 100).map(_.toString) ++ Seq(">")) + + assert((pagedTable.pageNavigation(11, 10, 100).head \\ "li").map(_.text.trim) === + Seq("<<", "<") ++ (11 to 20).map(_.toString) ++ Seq(">", ">>")) + assert((pagedTable.pageNavigation(93, 10, 97).head \\ "li").map(_.text.trim) === + Seq("<<", "<") ++ (91 to 97).map(_.toString) ++ Seq(">")) + } +} + +private[spark] class SeqPagedDataSource[T](seq: Seq[T], pageSize: Int) + extends PagedDataSource[T](pageSize) { + + override protected def dataSize: Int = seq.size + + override protected def sliceData(from: Int, to: Int): Seq[T] = 
seq.slice(from, to) +} From d45355ee224b734727255ff278a47801f5da7e93 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Tue, 21 Jul 2015 09:55:42 -0700 Subject: [PATCH 29/32] [SPARK-5423] [CORE] Register a TaskCompletionListener to make sure all resources are released Make `DiskMapIterator.cleanup` idempotent and register a TaskCompletionListener to make sure `cleanup` is called even if the iterator is not fully drained. Author: zsxwing Closes #7529 from zsxwing/SPARK-5423 and squashes the following commits: 3e3c413 [zsxwing] Remove TODO 9556c78 [zsxwing] Fix NullPointerException for tests 3d574d9 [zsxwing] Register a TaskCompletionListener to make sure release all resources --- .../collection/ExternalAppendOnlyMap.scala | 25 ++++++++++++++----- 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala index 1e4531ef395ae..d166037351c31 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala @@ -26,7 +26,7 @@ import scala.collection.mutable.ArrayBuffer import com.google.common.io.ByteStreams -import org.apache.spark.{Logging, SparkEnv} +import org.apache.spark.{Logging, SparkEnv, TaskContext} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.serializer.{DeserializationStream, Serializer} import org.apache.spark.storage.{BlockId, BlockManager} @@ -470,14 +470,27 @@ class ExternalAppendOnlyMap[K, V, C]( item } - // TODO: Ensure this gets called even if the iterator isn't drained. private def cleanup() { batchIndex = batchOffsets.length // Prevent reading any other batch val ds = deserializeStream - deserializeStream = null - fileStream = null - ds.close() - file.delete() + if (ds != null) { + ds.close() + deserializeStream = null + } + if (fileStream != null) { + fileStream.close() + fileStream = null + } + if (file.exists()) { + file.delete() + } + } + + val context = TaskContext.get() + // context is null in some tests of ExternalAppendOnlyMapSuite because these tests don't run in + // a TaskContext.
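For illustration, this is the general shape of the pattern the change above applies, reduced to a minimal sketch (the `registerCleanup` wrapper and its `cleanup` argument are hypothetical names, not part of the patch; only the TaskContext calls shown in the diff are assumed):

    import org.apache.spark.TaskContext

    // Minimal sketch: register an idempotent cleanup routine so it runs when the task
    // completes, even if the iterator that owns the resources is never fully drained.
    def registerCleanup(cleanup: () => Unit): Unit = {
      val context = TaskContext.get()
      // TaskContext.get() returns null outside of a running task (e.g. in local unit tests),
      // so only register the listener when a task context actually exists.
      if (context != null) {
        context.addTaskCompletionListener(_ => cleanup())
      }
    }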
+ if (context != null) { + context.addTaskCompletionListener(context => cleanup()) } } From 7f072c3d5ec50c65d76bd9f28fac124fce96a89e Mon Sep 17 00:00:00 2001 From: Tarek Auel Date: Tue, 21 Jul 2015 09:58:16 -0700 Subject: [PATCH 30/32] [SPARK-9154] [SQL] codegen StringFormat Jira: https://issues.apache.org/jira/browse/SPARK-9154 Author: Tarek Auel Closes #7546 from tarekauel/SPARK-9154 and squashes the following commits: a943d3e [Tarek Auel] [SPARK-9154] implicit input cast, added tests for null, support for null primitives 10b4de8 [Tarek Auel] [SPARK-9154][SQL] codegen removed fallback trait cd8322b [Tarek Auel] [SPARK-9154][SQL] codegen string format 086caba [Tarek Auel] [SPARK-9154][SQL] codegen string format --- .../expressions/stringOperations.scala | 42 ++++++++++++++++++- .../expressions/StringExpressionsSuite.scala | 18 ++++---- .../spark/sql/StringFunctionsSuite.scala | 10 +++++ 3 files changed, 59 insertions(+), 11 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index fe57d17f1ec14..280ae0e546358 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -526,7 +526,7 @@ case class StringRPad(str: Expression, len: Expression, pad: Expression) /** * Returns the input formatted according do printf-style format strings */ -case class StringFormat(children: Expression*) extends Expression with CodegenFallback { +case class StringFormat(children: Expression*) extends Expression with ImplicitCastInputTypes { require(children.nonEmpty, "printf() should take at least 1 argument") @@ -536,6 +536,10 @@ case class StringFormat(children: Expression*) extends Expression with CodegenFa private def format: Expression = children(0) private def args: Seq[Expression] = children.tail + override def inputTypes: Seq[AbstractDataType] = + children.zipWithIndex.map(x => if (x._2 == 0) StringType else AnyDataType) + + override def eval(input: InternalRow): Any = { val pattern = format.eval(input) if (pattern == null) { @@ -551,6 +555,42 @@ case class StringFormat(children: Expression*) extends Expression with CodegenFa } } + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val pattern = children.head.gen(ctx) + + val argListGen = children.tail.map(x => (x.dataType, x.gen(ctx))) + val argListCode = argListGen.map(_._2.code + "\n") + + val argListString = argListGen.foldLeft("")((s, v) => { + val nullSafeString = + if (ctx.boxedType(v._1) != ctx.javaType(v._1)) { + // Java primitives get boxed in order to allow null values. + s"(${v._2.isNull}) ? (${ctx.boxedType(v._1)}) null : " + + s"new ${ctx.boxedType(v._1)}(${v._2.primitive})" + } else { + s"(${v._2.isNull}) ? 
null : ${v._2.primitive}" + } + s + "," + nullSafeString + }) + + val form = ctx.freshName("formatter") + val formatter = classOf[java.util.Formatter].getName + val sb = ctx.freshName("sb") + val stringBuffer = classOf[StringBuffer].getName + s""" + ${pattern.code} + boolean ${ev.isNull} = ${pattern.isNull}; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + ${argListCode.mkString} + $stringBuffer $sb = new $stringBuffer(); + $formatter $form = new $formatter($sb, ${classOf[Locale].getName}.US); + $form.format(${pattern.primitive}.toString() $argListString); + ${ev.primitive} = UTF8String.fromString($sb.toString()); + } + """ + } + override def prettyName: String = "printf" } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index 96c540ab36f08..3c2d88731beb4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -351,18 +351,16 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("FORMAT") { - val f = 'f.string.at(0) - val d1 = 'd.int.at(1) - val s1 = 's.int.at(2) - - val row1 = create_row("aa%d%s", 12, "cc") - val row2 = create_row(null, 12, "cc") - checkEvaluation(StringFormat(Literal("aa%d%s"), Literal(123), Literal("a")), "aa123a", row1) + checkEvaluation(StringFormat(Literal("aa%d%s"), Literal(123), Literal("a")), "aa123a") checkEvaluation(StringFormat(Literal("aa")), "aa", create_row(null)) - checkEvaluation(StringFormat(Literal("aa%d%s"), Literal(123), Literal("a")), "aa123a", row1) + checkEvaluation(StringFormat(Literal("aa%d%s"), Literal(123), Literal("a")), "aa123a") + checkEvaluation(StringFormat(Literal("aa%d%s"), 12, "cc"), "aa12cc") - checkEvaluation(StringFormat(f, d1, s1), "aa12cc", row1) - checkEvaluation(StringFormat(f, d1, s1), null, row2) + checkEvaluation(StringFormat(Literal.create(null, StringType), 12, "cc"), null) + checkEvaluation( + StringFormat(Literal("aa%d%s"), Literal.create(null, IntegerType), "cc"), "aanullcc") + checkEvaluation( + StringFormat(Literal("aa%d%s"), 12, Literal.create(null, StringType)), "aa12null") } test("INSTR") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala index d1f855903ca4b..3702e73b4e74f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala @@ -132,6 +132,16 @@ class StringFunctionsSuite extends QueryTest { checkAnswer( df.selectExpr("printf(a, b, c)"), Row("aa123cc")) + + val df2 = Seq(("aa%d%s".getBytes, 123, "cc")).toDF("a", "b", "c") + + checkAnswer( + df2.select(formatString($"a", $"b", $"c"), formatString("aa%d%s", "b", "c")), + Row("aa123cc", "aa123cc")) + + checkAnswer( + df2.selectExpr("printf(a, b, c)"), + Row("aa123cc")) } test("string instr function") { From 89db3c0b6edcffed7e1e12c202e6827271ddba26 Mon Sep 17 00:00:00 2001 From: MechCoder Date: Tue, 21 Jul 2015 10:31:31 -0700 Subject: [PATCH 31/32] [SPARK-5989] [MLLIB] Model save/load for LDA Add support for saving and loading LDA both the local and distributed versions. 
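A minimal usage sketch of the save/load surface this patch adds, assuming an active SparkContext `sc`; the tiny corpus and the "myLDAModel" path are illustrative only and mirror the example added to mllib-clustering.md below:

    import org.apache.spark.mllib.clustering.{DistributedLDAModel, LDA}
    import org.apache.spark.mllib.linalg.Vectors

    // Toy corpus of (document id, term-count vector) pairs, just to have something to fit.
    val corpus = sc.parallelize(Seq(
      (0L, Vectors.dense(1.0, 2.0, 6.0)),
      (1L, Vectors.dense(1.0, 3.0, 0.0)),
      (2L, Vectors.dense(1.0, 4.0, 1.0))))

    // The default (EM) optimizer returns a DistributedLDAModel, as the new LDASuite test assumes.
    val ldaModel = new LDA().setK(3).run(corpus)

    // New in this patch: persist the model and load it back.
    ldaModel.save(sc, "myLDAModel")
    val sameModel = DistributedLDAModel.load(sc, "myLDAModel")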
Author: MechCoder Closes #6948 from MechCoder/lda_save_load and squashes the following commits: 49bcdce [MechCoder] minor style fixes cc14054 [MechCoder] minor 4587d1d [MechCoder] Minor changes c753122 [MechCoder] Load and save the model in private methods 2782326 [MechCoder] [SPARK-5989] Model save/load for LDA --- docs/mllib-clustering.md | 10 +- .../spark/mllib/clustering/LDAModel.scala | 228 +++++++++++++++++- .../spark/mllib/clustering/LDASuite.scala | 41 ++++ 3 files changed, 274 insertions(+), 5 deletions(-) diff --git a/docs/mllib-clustering.md b/docs/mllib-clustering.md index 0fc7036bffeaf..bb875ae2ae6cb 100644 --- a/docs/mllib-clustering.md +++ b/docs/mllib-clustering.md @@ -472,7 +472,7 @@ to the algorithm. We then output the topics, represented as probability distribu
    {% highlight scala %} -import org.apache.spark.mllib.clustering.LDA +import org.apache.spark.mllib.clustering.{LDA, DistributedLDAModel} import org.apache.spark.mllib.linalg.Vectors // Load and parse the data @@ -492,6 +492,11 @@ for (topic <- Range(0, 3)) { for (word <- Range(0, ldaModel.vocabSize)) { print(" " + topics(word, topic)); } println() } + +// Save and load model. +ldaModel.save(sc, "myLDAModel") +val sameModel = DistributedLDAModel.load(sc, "myLDAModel") + {% endhighlight %}
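The local model gets the same treatment. Below is a condensed sketch of what the new LDASuite test exercises for LocalLDAModel, assuming an active SparkContext `sc`; the topics matrix and path are made up, and because the LocalLDAModel constructor is package-private, constructing one directly like this only works from code inside org.apache.spark.mllib.clustering, as the test does:

    import org.apache.spark.mllib.clustering.LocalLDAModel
    import org.apache.spark.mllib.linalg.Matrices

    // 3 terms x 2 topics, column-major; the weights are arbitrary illustration values.
    val topics = Matrices.dense(3, 2, Array(0.5, 0.3, 0.2, 0.4, 0.4, 0.2))
    val localModel = new LocalLDAModel(topics)

    localModel.save(sc, "myLocalLDAModel")
    val sameLocalModel = LocalLDAModel.load(sc, "myLocalLDAModel")
    assert(sameLocalModel.k == localModel.k)
    assert(sameLocalModel.topicsMatrix.toArray.sameElements(localModel.topicsMatrix.toArray))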
    @@ -551,6 +556,9 @@ public class JavaLDAExample { } System.out.println(); } + + ldaModel.save(sc.sc(), "myLDAModel"); + DistributedLDAModel sameModel = DistributedLDAModel.load(sc.sc(), "myLDAModel"); } } {% endhighlight %} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala index 974b26924dfb8..920b57756b625 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala @@ -17,15 +17,25 @@ package org.apache.spark.mllib.clustering -import breeze.linalg.{DenseMatrix => BDM, normalize, sum => brzSum} +import breeze.linalg.{DenseMatrix => BDM, normalize, sum => brzSum, DenseVector => BDV} +import org.apache.hadoop.fs.Path + +import org.json4s.DefaultFormats +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.SparkContext import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaPairRDD -import org.apache.spark.graphx.{VertexId, EdgeContext, Graph} -import org.apache.spark.mllib.linalg.{Vectors, Vector, Matrices, Matrix} +import org.apache.spark.graphx.{VertexId, Edge, EdgeContext, Graph} +import org.apache.spark.mllib.linalg.{Vectors, Vector, Matrices, Matrix, DenseVector} +import org.apache.spark.mllib.util.{Saveable, Loader} import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{SQLContext, Row} import org.apache.spark.util.BoundedPriorityQueue + /** * :: Experimental :: * @@ -35,7 +45,7 @@ import org.apache.spark.util.BoundedPriorityQueue * including local and distributed data structures. */ @Experimental -abstract class LDAModel private[clustering] { +abstract class LDAModel private[clustering] extends Saveable { /** Number of topics */ def k: Int @@ -176,6 +186,11 @@ class LocalLDAModel private[clustering] ( }.toArray } + override protected def formatVersion = "1.0" + + override def save(sc: SparkContext, path: String): Unit = { + LocalLDAModel.SaveLoadV1_0.save(sc, path, topicsMatrix) + } // TODO // override def logLikelihood(documents: RDD[(Long, Vector)]): Double = ??? @@ -184,6 +199,80 @@ class LocalLDAModel private[clustering] ( } +@Experimental +object LocalLDAModel extends Loader[LocalLDAModel] { + + private object SaveLoadV1_0 { + + val thisFormatVersion = "1.0" + + val thisClassName = "org.apache.spark.mllib.clustering.LocalLDAModel" + + // Store the distribution of terms of each topic and the column index in topicsMatrix + // as a Row in data. 
+ case class Data(topic: Vector, index: Int) + + def save(sc: SparkContext, path: String, topicsMatrix: Matrix): Unit = { + val sqlContext = SQLContext.getOrCreate(sc) + import sqlContext.implicits._ + + val k = topicsMatrix.numCols + val metadata = compact(render + (("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ + ("k" -> k) ~ ("vocabSize" -> topicsMatrix.numRows))) + sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path)) + + val topicsDenseMatrix = topicsMatrix.toBreeze.toDenseMatrix + val topics = Range(0, k).map { topicInd => + Data(Vectors.dense((topicsDenseMatrix(::, topicInd).toArray)), topicInd) + }.toSeq + sc.parallelize(topics, 1).toDF().write.parquet(Loader.dataPath(path)) + } + + def load(sc: SparkContext, path: String): LocalLDAModel = { + val dataPath = Loader.dataPath(path) + val sqlContext = SQLContext.getOrCreate(sc) + val dataFrame = sqlContext.read.parquet(dataPath) + + Loader.checkSchema[Data](dataFrame.schema) + val topics = dataFrame.collect() + val vocabSize = topics(0).getAs[Vector](0).size + val k = topics.size + + val brzTopics = BDM.zeros[Double](vocabSize, k) + topics.foreach { case Row(vec: Vector, ind: Int) => + brzTopics(::, ind) := vec.toBreeze + } + new LocalLDAModel(Matrices.fromBreeze(brzTopics)) + } + } + + override def load(sc: SparkContext, path: String): LocalLDAModel = { + val (loadedClassName, loadedVersion, metadata) = Loader.loadMetadata(sc, path) + implicit val formats = DefaultFormats + val expectedK = (metadata \ "k").extract[Int] + val expectedVocabSize = (metadata \ "vocabSize").extract[Int] + val classNameV1_0 = SaveLoadV1_0.thisClassName + + val model = (loadedClassName, loadedVersion) match { + case (className, "1.0") if className == classNameV1_0 => + SaveLoadV1_0.load(sc, path) + case _ => throw new Exception( + s"LocalLDAModel.load did not recognize model with (className, format version):" + + s"($loadedClassName, $loadedVersion). Supported:\n" + + s" ($classNameV1_0, 1.0)") + } + + val topicsMatrix = model.topicsMatrix + require(expectedK == topicsMatrix.numCols, + s"LocalLDAModel requires $expectedK topics, got ${topicsMatrix.numCols} topics") + require(expectedVocabSize == topicsMatrix.numRows, + s"LocalLDAModel requires $expectedVocabSize terms for each topic, " + + s"but got ${topicsMatrix.numRows}") + model + } +} + /** * :: Experimental :: * @@ -354,4 +443,135 @@ class DistributedLDAModel private ( // TODO: // override def topicDistributions(documents: RDD[(Long, Vector)]): RDD[(Long, Vector)] = ??? + override protected def formatVersion = "1.0" + + override def save(sc: SparkContext, path: String): Unit = { + DistributedLDAModel.SaveLoadV1_0.save( + sc, path, graph, globalTopicTotals, k, vocabSize, docConcentration, topicConcentration, + iterationTimes) + } +} + + +@Experimental +object DistributedLDAModel extends Loader[DistributedLDAModel] { + + private object SaveLoadV1_0 { + + val thisFormatVersion = "1.0" + + val classNameV1_0 = "org.apache.spark.mllib.clustering.DistributedLDAModel" + + // Store globalTopicTotals as a Vector. + case class Data(globalTopicTotals: Vector) + + // Store each term and document vertex with an id and the topicWeights. + case class VertexData(id: Long, topicWeights: Vector) + + // Store each edge with the source id, destination id and tokenCounts. 
+ case class EdgeData(srcId: Long, dstId: Long, tokenCounts: Double) + + def save( + sc: SparkContext, + path: String, + graph: Graph[LDA.TopicCounts, LDA.TokenCount], + globalTopicTotals: LDA.TopicCounts, + k: Int, + vocabSize: Int, + docConcentration: Double, + topicConcentration: Double, + iterationTimes: Array[Double]): Unit = { + val sqlContext = SQLContext.getOrCreate(sc) + import sqlContext.implicits._ + + val metadata = compact(render + (("class" -> classNameV1_0) ~ ("version" -> thisFormatVersion) ~ + ("k" -> k) ~ ("vocabSize" -> vocabSize) ~ ("docConcentration" -> docConcentration) ~ + ("topicConcentration" -> topicConcentration) ~ + ("iterationTimes" -> iterationTimes.toSeq))) + sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path)) + + val newPath = new Path(Loader.dataPath(path), "globalTopicTotals").toUri.toString + sc.parallelize(Seq(Data(Vectors.fromBreeze(globalTopicTotals)))).toDF() + .write.parquet(newPath) + + val verticesPath = new Path(Loader.dataPath(path), "topicCounts").toUri.toString + graph.vertices.map { case (ind, vertex) => + VertexData(ind, Vectors.fromBreeze(vertex)) + }.toDF().write.parquet(verticesPath) + + val edgesPath = new Path(Loader.dataPath(path), "tokenCounts").toUri.toString + graph.edges.map { case Edge(srcId, dstId, prop) => + EdgeData(srcId, dstId, prop) + }.toDF().write.parquet(edgesPath) + } + + def load( + sc: SparkContext, + path: String, + vocabSize: Int, + docConcentration: Double, + topicConcentration: Double, + iterationTimes: Array[Double]): DistributedLDAModel = { + val dataPath = new Path(Loader.dataPath(path), "globalTopicTotals").toUri.toString + val vertexDataPath = new Path(Loader.dataPath(path), "topicCounts").toUri.toString + val edgeDataPath = new Path(Loader.dataPath(path), "tokenCounts").toUri.toString + val sqlContext = SQLContext.getOrCreate(sc) + val dataFrame = sqlContext.read.parquet(dataPath) + val vertexDataFrame = sqlContext.read.parquet(vertexDataPath) + val edgeDataFrame = sqlContext.read.parquet(edgeDataPath) + + Loader.checkSchema[Data](dataFrame.schema) + Loader.checkSchema[VertexData](vertexDataFrame.schema) + Loader.checkSchema[EdgeData](edgeDataFrame.schema) + val globalTopicTotals: LDA.TopicCounts = + dataFrame.first().getAs[Vector](0).toBreeze.toDenseVector + val vertices: RDD[(VertexId, LDA.TopicCounts)] = vertexDataFrame.map { + case Row(ind: Long, vec: Vector) => (ind, vec.toBreeze.toDenseVector) + } + + val edges: RDD[Edge[LDA.TokenCount]] = edgeDataFrame.map { + case Row(srcId: Long, dstId: Long, prop: Double) => Edge(srcId, dstId, prop) + } + val graph: Graph[LDA.TopicCounts, LDA.TokenCount] = Graph(vertices, edges) + + new DistributedLDAModel(graph, globalTopicTotals, globalTopicTotals.length, vocabSize, + docConcentration, topicConcentration, iterationTimes) + } + + } + + override def load(sc: SparkContext, path: String): DistributedLDAModel = { + val (loadedClassName, loadedVersion, metadata) = Loader.loadMetadata(sc, path) + implicit val formats = DefaultFormats + val expectedK = (metadata \ "k").extract[Int] + val vocabSize = (metadata \ "vocabSize").extract[Int] + val docConcentration = (metadata \ "docConcentration").extract[Double] + val topicConcentration = (metadata \ "topicConcentration").extract[Double] + val iterationTimes = (metadata \ "iterationTimes").extract[Seq[Double]] + val classNameV1_0 = SaveLoadV1_0.classNameV1_0 + + val model = (loadedClassName, loadedVersion) match { + case (className, "1.0") if className == classNameV1_0 => { + 
DistributedLDAModel.SaveLoadV1_0.load( + sc, path, vocabSize, docConcentration, topicConcentration, iterationTimes.toArray) + } + case _ => throw new Exception( + s"DistributedLDAModel.load did not recognize model with (className, format version):" + + s"($loadedClassName, $loadedVersion). Supported: ($classNameV1_0, 1.0)") + } + + require(model.vocabSize == vocabSize, + s"DistributedLDAModel requires $vocabSize vocabSize, got ${model.vocabSize} vocabSize") + require(model.docConcentration == docConcentration, + s"DistributedLDAModel requires $docConcentration docConcentration, " + + s"got ${model.docConcentration} docConcentration") + require(model.topicConcentration == topicConcentration, + s"DistributedLDAModel requires $topicConcentration docConcentration, " + + s"got ${model.topicConcentration} docConcentration") + require(expectedK == model.k, + s"DistributedLDAModel requires $expectedK topics, got ${model.k} topics") + model + } + } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala index 03a8a2538b464..721a065658951 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala @@ -23,6 +23,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{Vector, DenseMatrix, Matrix, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.util.Utils class LDASuite extends SparkFunSuite with MLlibTestSparkContext { @@ -217,6 +218,46 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext { } } + test("model save/load") { + // Test for LocalLDAModel. + val localModel = new LocalLDAModel(tinyTopics) + val tempDir1 = Utils.createTempDir() + val path1 = tempDir1.toURI.toString + + // Test for DistributedLDAModel. + val k = 3 + val docConcentration = 1.2 + val topicConcentration = 1.5 + val lda = new LDA() + lda.setK(k) + .setDocConcentration(docConcentration) + .setTopicConcentration(topicConcentration) + .setMaxIterations(5) + .setSeed(12345) + val corpus = sc.parallelize(tinyCorpus, 2) + val distributedModel: DistributedLDAModel = lda.run(corpus).asInstanceOf[DistributedLDAModel] + val tempDir2 = Utils.createTempDir() + val path2 = tempDir2.toURI.toString + + try { + localModel.save(sc, path1) + distributedModel.save(sc, path2) + val samelocalModel = LocalLDAModel.load(sc, path1) + assert(samelocalModel.topicsMatrix === localModel.topicsMatrix) + assert(samelocalModel.k === localModel.k) + assert(samelocalModel.vocabSize === localModel.vocabSize) + + val sameDistributedModel = DistributedLDAModel.load(sc, path2) + assert(distributedModel.topicsMatrix === sameDistributedModel.topicsMatrix) + assert(distributedModel.k === sameDistributedModel.k) + assert(distributedModel.vocabSize === sameDistributedModel.vocabSize) + assert(distributedModel.iterationTimes === sameDistributedModel.iterationTimes) + } finally { + Utils.deleteRecursively(tempDir1) + Utils.deleteRecursively(tempDir2) + } + } + } private[clustering] object LDASuite { From 87d890cc105a7f41478433b28f53c9aa431db211 Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Tue, 21 Jul 2015 11:18:39 -0700 Subject: [PATCH 32/32] Revert "[SPARK-9154] [SQL] codegen StringFormat" This reverts commit 7f072c3d5ec50c65d76bd9f28fac124fce96a89e. 
Revert #7546 Author: Michael Armbrust Closes #7570 from marmbrus/revert9154 and squashes the following commits: ed2c32a [Michael Armbrust] Revert "[SPARK-9154] [SQL] codegen StringFormat" --- .../expressions/stringOperations.scala | 42 +------------------ .../expressions/StringExpressionsSuite.scala | 18 ++++---- .../spark/sql/StringFunctionsSuite.scala | 10 ----- 3 files changed, 11 insertions(+), 59 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 280ae0e546358..fe57d17f1ec14 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -526,7 +526,7 @@ case class StringRPad(str: Expression, len: Expression, pad: Expression) /** * Returns the input formatted according do printf-style format strings */ -case class StringFormat(children: Expression*) extends Expression with ImplicitCastInputTypes { +case class StringFormat(children: Expression*) extends Expression with CodegenFallback { require(children.nonEmpty, "printf() should take at least 1 argument") @@ -536,10 +536,6 @@ case class StringFormat(children: Expression*) extends Expression with ImplicitC private def format: Expression = children(0) private def args: Seq[Expression] = children.tail - override def inputTypes: Seq[AbstractDataType] = - children.zipWithIndex.map(x => if (x._2 == 0) StringType else AnyDataType) - - override def eval(input: InternalRow): Any = { val pattern = format.eval(input) if (pattern == null) { @@ -555,42 +551,6 @@ case class StringFormat(children: Expression*) extends Expression with ImplicitC } } - override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val pattern = children.head.gen(ctx) - - val argListGen = children.tail.map(x => (x.dataType, x.gen(ctx))) - val argListCode = argListGen.map(_._2.code + "\n") - - val argListString = argListGen.foldLeft("")((s, v) => { - val nullSafeString = - if (ctx.boxedType(v._1) != ctx.javaType(v._1)) { - // Java primitives get boxed in order to allow null values. - s"(${v._2.isNull}) ? (${ctx.boxedType(v._1)}) null : " + - s"new ${ctx.boxedType(v._1)}(${v._2.primitive})" - } else { - s"(${v._2.isNull}) ? 
null : ${v._2.primitive}" - } - s + "," + nullSafeString - }) - - val form = ctx.freshName("formatter") - val formatter = classOf[java.util.Formatter].getName - val sb = ctx.freshName("sb") - val stringBuffer = classOf[StringBuffer].getName - s""" - ${pattern.code} - boolean ${ev.isNull} = ${pattern.isNull}; - ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; - if (!${ev.isNull}) { - ${argListCode.mkString} - $stringBuffer $sb = new $stringBuffer(); - $formatter $form = new $formatter($sb, ${classOf[Locale].getName}.US); - $form.format(${pattern.primitive}.toString() $argListString); - ${ev.primitive} = UTF8String.fromString($sb.toString()); - } - """ - } - override def prettyName: String = "printf" } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index 3c2d88731beb4..96c540ab36f08 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -351,16 +351,18 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("FORMAT") { - checkEvaluation(StringFormat(Literal("aa%d%s"), Literal(123), Literal("a")), "aa123a") + val f = 'f.string.at(0) + val d1 = 'd.int.at(1) + val s1 = 's.int.at(2) + + val row1 = create_row("aa%d%s", 12, "cc") + val row2 = create_row(null, 12, "cc") + checkEvaluation(StringFormat(Literal("aa%d%s"), Literal(123), Literal("a")), "aa123a", row1) checkEvaluation(StringFormat(Literal("aa")), "aa", create_row(null)) - checkEvaluation(StringFormat(Literal("aa%d%s"), Literal(123), Literal("a")), "aa123a") - checkEvaluation(StringFormat(Literal("aa%d%s"), 12, "cc"), "aa12cc") + checkEvaluation(StringFormat(Literal("aa%d%s"), Literal(123), Literal("a")), "aa123a", row1) - checkEvaluation(StringFormat(Literal.create(null, StringType), 12, "cc"), null) - checkEvaluation( - StringFormat(Literal("aa%d%s"), Literal.create(null, IntegerType), "cc"), "aanullcc") - checkEvaluation( - StringFormat(Literal("aa%d%s"), 12, Literal.create(null, StringType)), "aa12null") + checkEvaluation(StringFormat(f, d1, s1), "aa12cc", row1) + checkEvaluation(StringFormat(f, d1, s1), null, row2) } test("INSTR") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala index 3702e73b4e74f..d1f855903ca4b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala @@ -132,16 +132,6 @@ class StringFunctionsSuite extends QueryTest { checkAnswer( df.selectExpr("printf(a, b, c)"), Row("aa123cc")) - - val df2 = Seq(("aa%d%s".getBytes, 123, "cc")).toDF("a", "b", "c") - - checkAnswer( - df2.select(formatString($"a", $"b", $"c"), formatString("aa%d%s", "b", "c")), - Row("aa123cc", "aa123cc")) - - checkAnswer( - df2.selectExpr("printf(a, b, c)"), - Row("aa123cc")) } test("string instr function") {