diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala index b68432b1a128f..868ad934daf17 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala @@ -32,9 +32,9 @@ class ConstraintPropagationSuite extends SparkFunSuite { private def resolveColumn(plan: LogicalPlan, columnName: String): Expression = plan.resolveQuoted(columnName, caseInsensitiveResolution).get - private def verifyConstraints(found: Set[Expression], expected: Set[Expression]): Unit = { - val missing = expected.filterNot(i => found.map(_.semanticEquals(i)).reduce(_ || _)) - val extra = found.filterNot(i => expected.map(_.semanticEquals(i)).reduce(_ || _)) + private def verifyConstraints(found: ExpressionSet, expected: ExpressionSet): Unit = { + val missing = expected -- found + val extra = found -- expected if (missing.nonEmpty || extra.nonEmpty) { fail( s""" @@ -58,18 +58,18 @@ class ConstraintPropagationSuite extends SparkFunSuite { verifyConstraints(tr .where('a.attr > 10) .analyze.constraints, - Set(resolveColumn(tr, "a") > 10, - IsNotNull(resolveColumn(tr, "a")))) + ExpressionSet(Seq(resolveColumn(tr, "a") > 10, + IsNotNull(resolveColumn(tr, "a"))))) verifyConstraints(tr .where('a.attr > 10) .select('c.attr, 'a.attr) .where('c.attr < 100) .analyze.constraints, - Set(resolveColumn(tr, "a") > 10, + ExpressionSet(Seq(resolveColumn(tr, "a") > 10, resolveColumn(tr, "c") < 100, IsNotNull(resolveColumn(tr, "a")), - IsNotNull(resolveColumn(tr, "c")))) + IsNotNull(resolveColumn(tr, "c"))))) } test("propagating constraints in aggregate") { @@ -81,10 +81,10 @@ class ConstraintPropagationSuite extends SparkFunSuite { .groupBy('a, 'c, 'b)('a, 'c.as("c1"), count('a).as("a3")).select('c1, 'a).analyze verifyConstraints(aliasedRelation.analyze.constraints, - Set(resolveColumn(aliasedRelation.analyze, "c1") > 10, + ExpressionSet(Seq(resolveColumn(aliasedRelation.analyze, "c1") > 10, IsNotNull(resolveColumn(aliasedRelation.analyze, "c1")), resolveColumn(aliasedRelation.analyze, "a") < 5, - IsNotNull(resolveColumn(aliasedRelation.analyze, "a")))) + IsNotNull(resolveColumn(aliasedRelation.analyze, "a"))))) } test("propagating constraints in aliases") { @@ -95,11 +95,11 @@ class ConstraintPropagationSuite extends SparkFunSuite { val aliasedRelation = tr.where('a.attr > 10).select('a.as('x), 'b, 'b.as('y), 'a.as('z)) verifyConstraints(aliasedRelation.analyze.constraints, - Set(resolveColumn(aliasedRelation.analyze, "x") > 10, + ExpressionSet(Seq(resolveColumn(aliasedRelation.analyze, "x") > 10, IsNotNull(resolveColumn(aliasedRelation.analyze, "x")), resolveColumn(aliasedRelation.analyze, "b") <=> resolveColumn(aliasedRelation.analyze, "y"), resolveColumn(aliasedRelation.analyze, "z") > 10, - IsNotNull(resolveColumn(aliasedRelation.analyze, "z")))) + IsNotNull(resolveColumn(aliasedRelation.analyze, "z"))))) } test("propagating constraints in union") { @@ -118,8 +118,8 @@ class ConstraintPropagationSuite extends SparkFunSuite { .unionAll(tr2.where('d.attr > 10) .unionAll(tr3.where('g.attr > 10))) .analyze.constraints, - Set(resolveColumn(tr1, "a") > 10, - IsNotNull(resolveColumn(tr1, "a")))) + ExpressionSet(Seq(resolveColumn(tr1, "a") > 10, + IsNotNull(resolveColumn(tr1, "a"))))) } test("propagating constraints in intersect") { @@ -130,10 +130,10 @@ class ConstraintPropagationSuite extends SparkFunSuite { .where('a.attr > 10) .intersect(tr2.where('b.attr < 100)) .analyze.constraints, - Set(resolveColumn(tr1, "a") > 10, + ExpressionSet(Seq(resolveColumn(tr1, "a") > 10, resolveColumn(tr1, "b") < 100, IsNotNull(resolveColumn(tr1, "a")), - IsNotNull(resolveColumn(tr1, "b")))) + IsNotNull(resolveColumn(tr1, "b"))))) } test("propagating constraints in except") { @@ -143,8 +143,8 @@ class ConstraintPropagationSuite extends SparkFunSuite { .where('a.attr > 10) .except(tr2.where('b.attr < 100)) .analyze.constraints, - Set(resolveColumn(tr1, "a") > 10, - IsNotNull(resolveColumn(tr1, "a")))) + ExpressionSet(Seq(resolveColumn(tr1, "a") > 10, + IsNotNull(resolveColumn(tr1, "a"))))) } test("propagating constraints in inner join") { @@ -154,13 +154,13 @@ class ConstraintPropagationSuite extends SparkFunSuite { .where('a.attr > 10) .join(tr2.where('d.attr < 100), Inner, Some("tr1.a".attr === "tr2.a".attr)) .analyze.constraints, - Set(tr1.resolveQuoted("a", caseInsensitiveResolution).get > 10, + ExpressionSet(Seq(tr1.resolveQuoted("a", caseInsensitiveResolution).get > 10, tr2.resolveQuoted("d", caseInsensitiveResolution).get < 100, tr1.resolveQuoted("a", caseInsensitiveResolution).get === tr2.resolveQuoted("a", caseInsensitiveResolution).get, IsNotNull(tr2.resolveQuoted("a", caseInsensitiveResolution).get), IsNotNull(tr1.resolveQuoted("a", caseInsensitiveResolution).get), - IsNotNull(tr2.resolveQuoted("d", caseInsensitiveResolution).get))) + IsNotNull(tr2.resolveQuoted("d", caseInsensitiveResolution).get)))) } test("propagating constraints in left-semi join") { @@ -170,8 +170,8 @@ class ConstraintPropagationSuite extends SparkFunSuite { .where('a.attr > 10) .join(tr2.where('d.attr < 100), LeftSemi, Some("tr1.a".attr === "tr2.a".attr)) .analyze.constraints, - Set(tr1.resolveQuoted("a", caseInsensitiveResolution).get > 10, - IsNotNull(tr1.resolveQuoted("a", caseInsensitiveResolution).get))) + ExpressionSet(Seq(tr1.resolveQuoted("a", caseInsensitiveResolution).get > 10, + IsNotNull(tr1.resolveQuoted("a", caseInsensitiveResolution).get)))) } test("propagating constraints in left-outer join") { @@ -181,8 +181,8 @@ class ConstraintPropagationSuite extends SparkFunSuite { .where('a.attr > 10) .join(tr2.where('d.attr < 100), LeftOuter, Some("tr1.a".attr === "tr2.a".attr)) .analyze.constraints, - Set(tr1.resolveQuoted("a", caseInsensitiveResolution).get > 10, - IsNotNull(tr1.resolveQuoted("a", caseInsensitiveResolution).get))) + ExpressionSet(Seq(tr1.resolveQuoted("a", caseInsensitiveResolution).get > 10, + IsNotNull(tr1.resolveQuoted("a", caseInsensitiveResolution).get)))) } test("propagating constraints in right-outer join") { @@ -192,8 +192,8 @@ class ConstraintPropagationSuite extends SparkFunSuite { .where('a.attr > 10) .join(tr2.where('d.attr < 100), RightOuter, Some("tr1.a".attr === "tr2.a".attr)) .analyze.constraints, - Set(tr2.resolveQuoted("d", caseInsensitiveResolution).get < 100, - IsNotNull(tr2.resolveQuoted("d", caseInsensitiveResolution).get))) + ExpressionSet(Seq(tr2.resolveQuoted("d", caseInsensitiveResolution).get < 100, + IsNotNull(tr2.resolveQuoted("d", caseInsensitiveResolution).get)))) } test("propagating constraints in full-outer join") {