Skip to content

Commit

Permalink
[SPARK-13781][SQL] Use ExpressionSets in ConstraintPropagationSuite
Browse files Browse the repository at this point in the history
## What changes were proposed in this pull request?

This PR is a small follow up on apache#11338 (https://issues.apache.org/jira/browse/SPARK-13092) to use `ExpressionSet` as part of the verification logic in `ConstraintPropagationSuite`.
## How was this patch tested?

No new tests added. Just changes the verification logic in `ConstraintPropagationSuite`.

Author: Sameer Agarwal <sameer@databricks.com>

Closes apache#11611 from sameeragarwal/expression-set.
  • Loading branch information
sameeragarwal authored and roygao94 committed Mar 22, 2016
1 parent ac4156b commit 363d879
Showing 1 changed file with 25 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@ class ConstraintPropagationSuite extends SparkFunSuite {
private def resolveColumn(plan: LogicalPlan, columnName: String): Expression =
plan.resolveQuoted(columnName, caseInsensitiveResolution).get

private def verifyConstraints(found: Set[Expression], expected: Set[Expression]): Unit = {
val missing = expected.filterNot(i => found.map(_.semanticEquals(i)).reduce(_ || _))
val extra = found.filterNot(i => expected.map(_.semanticEquals(i)).reduce(_ || _))
private def verifyConstraints(found: ExpressionSet, expected: ExpressionSet): Unit = {
val missing = expected -- found
val extra = found -- expected
if (missing.nonEmpty || extra.nonEmpty) {
fail(
s"""
Expand All @@ -58,18 +58,18 @@ class ConstraintPropagationSuite extends SparkFunSuite {
verifyConstraints(tr
.where('a.attr > 10)
.analyze.constraints,
Set(resolveColumn(tr, "a") > 10,
IsNotNull(resolveColumn(tr, "a"))))
ExpressionSet(Seq(resolveColumn(tr, "a") > 10,
IsNotNull(resolveColumn(tr, "a")))))

verifyConstraints(tr
.where('a.attr > 10)
.select('c.attr, 'a.attr)
.where('c.attr < 100)
.analyze.constraints,
Set(resolveColumn(tr, "a") > 10,
ExpressionSet(Seq(resolveColumn(tr, "a") > 10,
resolveColumn(tr, "c") < 100,
IsNotNull(resolveColumn(tr, "a")),
IsNotNull(resolveColumn(tr, "c"))))
IsNotNull(resolveColumn(tr, "c")))))
}

test("propagating constraints in aggregate") {
Expand All @@ -81,10 +81,10 @@ class ConstraintPropagationSuite extends SparkFunSuite {
.groupBy('a, 'c, 'b)('a, 'c.as("c1"), count('a).as("a3")).select('c1, 'a).analyze

verifyConstraints(aliasedRelation.analyze.constraints,
Set(resolveColumn(aliasedRelation.analyze, "c1") > 10,
ExpressionSet(Seq(resolveColumn(aliasedRelation.analyze, "c1") > 10,
IsNotNull(resolveColumn(aliasedRelation.analyze, "c1")),
resolveColumn(aliasedRelation.analyze, "a") < 5,
IsNotNull(resolveColumn(aliasedRelation.analyze, "a"))))
IsNotNull(resolveColumn(aliasedRelation.analyze, "a")))))
}

test("propagating constraints in aliases") {
Expand All @@ -95,11 +95,11 @@ class ConstraintPropagationSuite extends SparkFunSuite {
val aliasedRelation = tr.where('a.attr > 10).select('a.as('x), 'b, 'b.as('y), 'a.as('z))

verifyConstraints(aliasedRelation.analyze.constraints,
Set(resolveColumn(aliasedRelation.analyze, "x") > 10,
ExpressionSet(Seq(resolveColumn(aliasedRelation.analyze, "x") > 10,
IsNotNull(resolveColumn(aliasedRelation.analyze, "x")),
resolveColumn(aliasedRelation.analyze, "b") <=> resolveColumn(aliasedRelation.analyze, "y"),
resolveColumn(aliasedRelation.analyze, "z") > 10,
IsNotNull(resolveColumn(aliasedRelation.analyze, "z"))))
IsNotNull(resolveColumn(aliasedRelation.analyze, "z")))))
}

test("propagating constraints in union") {
Expand All @@ -118,8 +118,8 @@ class ConstraintPropagationSuite extends SparkFunSuite {
.unionAll(tr2.where('d.attr > 10)
.unionAll(tr3.where('g.attr > 10)))
.analyze.constraints,
Set(resolveColumn(tr1, "a") > 10,
IsNotNull(resolveColumn(tr1, "a"))))
ExpressionSet(Seq(resolveColumn(tr1, "a") > 10,
IsNotNull(resolveColumn(tr1, "a")))))
}

test("propagating constraints in intersect") {
Expand All @@ -130,10 +130,10 @@ class ConstraintPropagationSuite extends SparkFunSuite {
.where('a.attr > 10)
.intersect(tr2.where('b.attr < 100))
.analyze.constraints,
Set(resolveColumn(tr1, "a") > 10,
ExpressionSet(Seq(resolveColumn(tr1, "a") > 10,
resolveColumn(tr1, "b") < 100,
IsNotNull(resolveColumn(tr1, "a")),
IsNotNull(resolveColumn(tr1, "b"))))
IsNotNull(resolveColumn(tr1, "b")))))
}

test("propagating constraints in except") {
Expand All @@ -143,8 +143,8 @@ class ConstraintPropagationSuite extends SparkFunSuite {
.where('a.attr > 10)
.except(tr2.where('b.attr < 100))
.analyze.constraints,
Set(resolveColumn(tr1, "a") > 10,
IsNotNull(resolveColumn(tr1, "a"))))
ExpressionSet(Seq(resolveColumn(tr1, "a") > 10,
IsNotNull(resolveColumn(tr1, "a")))))
}

test("propagating constraints in inner join") {
Expand All @@ -154,13 +154,13 @@ class ConstraintPropagationSuite extends SparkFunSuite {
.where('a.attr > 10)
.join(tr2.where('d.attr < 100), Inner, Some("tr1.a".attr === "tr2.a".attr))
.analyze.constraints,
Set(tr1.resolveQuoted("a", caseInsensitiveResolution).get > 10,
ExpressionSet(Seq(tr1.resolveQuoted("a", caseInsensitiveResolution).get > 10,
tr2.resolveQuoted("d", caseInsensitiveResolution).get < 100,
tr1.resolveQuoted("a", caseInsensitiveResolution).get ===
tr2.resolveQuoted("a", caseInsensitiveResolution).get,
IsNotNull(tr2.resolveQuoted("a", caseInsensitiveResolution).get),
IsNotNull(tr1.resolveQuoted("a", caseInsensitiveResolution).get),
IsNotNull(tr2.resolveQuoted("d", caseInsensitiveResolution).get)))
IsNotNull(tr2.resolveQuoted("d", caseInsensitiveResolution).get))))
}

test("propagating constraints in left-semi join") {
Expand All @@ -170,8 +170,8 @@ class ConstraintPropagationSuite extends SparkFunSuite {
.where('a.attr > 10)
.join(tr2.where('d.attr < 100), LeftSemi, Some("tr1.a".attr === "tr2.a".attr))
.analyze.constraints,
Set(tr1.resolveQuoted("a", caseInsensitiveResolution).get > 10,
IsNotNull(tr1.resolveQuoted("a", caseInsensitiveResolution).get)))
ExpressionSet(Seq(tr1.resolveQuoted("a", caseInsensitiveResolution).get > 10,
IsNotNull(tr1.resolveQuoted("a", caseInsensitiveResolution).get))))
}

test("propagating constraints in left-outer join") {
Expand All @@ -181,8 +181,8 @@ class ConstraintPropagationSuite extends SparkFunSuite {
.where('a.attr > 10)
.join(tr2.where('d.attr < 100), LeftOuter, Some("tr1.a".attr === "tr2.a".attr))
.analyze.constraints,
Set(tr1.resolveQuoted("a", caseInsensitiveResolution).get > 10,
IsNotNull(tr1.resolveQuoted("a", caseInsensitiveResolution).get)))
ExpressionSet(Seq(tr1.resolveQuoted("a", caseInsensitiveResolution).get > 10,
IsNotNull(tr1.resolveQuoted("a", caseInsensitiveResolution).get))))
}

test("propagating constraints in right-outer join") {
Expand All @@ -192,8 +192,8 @@ class ConstraintPropagationSuite extends SparkFunSuite {
.where('a.attr > 10)
.join(tr2.where('d.attr < 100), RightOuter, Some("tr1.a".attr === "tr2.a".attr))
.analyze.constraints,
Set(tr2.resolveQuoted("d", caseInsensitiveResolution).get < 100,
IsNotNull(tr2.resolveQuoted("d", caseInsensitiveResolution).get)))
ExpressionSet(Seq(tr2.resolveQuoted("d", caseInsensitiveResolution).get < 100,
IsNotNull(tr2.resolveQuoted("d", caseInsensitiveResolution).get))))
}

test("propagating constraints in full-outer join") {
Expand Down

0 comments on commit 363d879

Please sign in to comment.