Skip to content

Commit

Permalink
airframe-sql: Intersect should have attributes of all source relations
Browse files Browse the repository at this point in the history
  • Loading branch information
takezoe committed Nov 25, 2022
1 parent fda8293 commit 33205aa
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ object TypeResolver extends LogSupport {
TypeResolver.resolveRegularRelation _ ::
TypeResolver.resolveColumns _ ::
TypeResolver.resolveUnion _ ::
TypeResolver.resolveIntersect _ ::
Nil
}

Expand Down Expand Up @@ -177,7 +178,8 @@ object TypeResolver extends LogSupport {
def resolveRegularRelation(context: AnalyzerContext): PlanRewriter = {
case filter @ Filter(child, filterExpr, _) =>
filter.transformExpressions { case x: Expression => resolveExpression(context, x, filter.inputAttributes) }
case u: Union => u // UNION is resolved later by resolveUnion()
case u: Union => u // UNION is resolved later by resolveUnion()
case u: Intersect => u // INTERSECT is resolved later by resolveIntersect()
case r: Relation =>
r.transformExpressions { case x: Expression => resolveExpression(context, x, r.inputAttributes) }
}
Expand All @@ -188,6 +190,18 @@ object TypeResolver extends LogSupport {
resolved
}

def resolveIntersect(context: AnalyzerContext): PlanRewriter = {
case u@Intersect(_, None, _) =>
val resolvedOutputs = u.outputAttributes.collect { case SingleColumn(UnionColumn(inputs, _), _, _, _) =>
val resolved = inputs
.map { expr =>
resolveExpression(context, expr, u.inputAttributes)
}.collect { case a: ResolvedAttribute => a }
resolved.head.copy(sourceColumns = resolved.flatMap(_.sourceColumns))
}
u.copy(resolvedOutputs = Some(resolvedOutputs))
}

def resolveUnion(context: AnalyzerContext): PlanRewriter = { case u @ Union(_, None, _) =>
val resolvedOutputs = u.outputAttributes.collect { case SingleColumn(UnionColumn(inputs, _), _, _, _) =>
val resolved = inputs
Expand Down Expand Up @@ -296,6 +310,7 @@ object TypeResolver extends LogSupport {
def resolveExpression(context: AnalyzerContext, expr: Expression, inputAttributes: Seq[Attribute]): Expression = {
findMatchInInputAttributes(context, expr, inputAttributes) match {
case lst if lst.length > 1 =>
println(expr + " -> " + lst)
throw SQLErrorCode.SyntaxError.newException(s"${expr.sqlExpr} is ambiguous", expr.nodeLocation)
case lst =>
lst.headOption.getOrElse(expr)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -601,15 +601,27 @@ object LogicalPlan {
sealed trait SetOperation extends Relation {
override def children: Seq[Relation]
}
case class Intersect(relations: Seq[Relation], nodeLocation: Option[NodeLocation]) extends SetOperation {
case class Intersect(relations: Seq[Relation], resolvedOutputs: Option[Seq[Attribute]], nodeLocation: Option[NodeLocation]) extends SetOperation {
override def children: Seq[Relation] = relations
override def sig(config: QuerySignatureConfig): String = {
s"IX(${relations.map(_.sig(config)).mkString(",")})"
}
override def inputAttributes: Seq[Attribute] =
relations.head.inputAttributes
override def outputAttributes: Seq[Attribute] =
relations.head.outputAttributes
relations.flatMap(_.inputAttributes)
override def outputAttributes: Seq[Attribute] = {
val out = resolvedOutputs.getOrElse {
relations.head.outputAttributes.zipWithIndex.map { case (output, i) =>
SingleColumn(
UnionColumn(relations.map(_.outputAttributes(i)), output.nodeLocation),
None,
None,
output.nodeLocation
)
}
}
println("** out: " + out)
out
}
}
case class Except(left: Relation, right: Relation, nodeLocation: Option[NodeLocation]) extends SetOperation {
override def children: Seq[Relation] = Seq(left, right)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ object SQLGenerator extends LogSupport {
if (isDistinct) "UNION" else "UNION ALL"
case Except(left, right, _) =>
if (isDistinct) "EXCEPT" else "EXCEPT ALL"
case Intersect(relations, _) =>
case Intersect(relations, _, _) =>
if (isDistinct) "INTERSECT" else "INTERSECT ALL"
}
s.children.map(printRelation).mkString(s" ${op} ")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ class SQLInterpreter(withNodeLocation: Boolean = true) extends SqlBaseBaseVisito
.map(visitSetQuantifier(_).isDistinct)
.getOrElse(true)
val base = if (ctx.INTERSECT() != null) {
Intersect(children, getLocation(ctx.INTERSECT()))
Intersect(children, None, getLocation(ctx.INTERSECT()))
} else if (ctx.UNION() != null) {
Union(children, None, getLocation(ctx.UNION()))
} else if (ctx.EXCEPT() != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ import wvlet.airframe.sql.analyzer.SQLAnalyzer.PlanRewriter
import wvlet.airframe.sql.catalog.Catalog._
import wvlet.airframe.sql.catalog.{Catalog, DataType, InMemoryCatalog}
import wvlet.airframe.sql.model.Expression._
import wvlet.airframe.sql.model.LogicalPlan.{Aggregate, Filter, Project}
import wvlet.airframe.sql.model.LogicalPlan.{Aggregate, Distinct, Filter, Intersect, Project}
import wvlet.airframe.sql.model.{Expression, LogicalPlan, NodeLocation, ResolvedAttribute, SourceColumn}
import wvlet.airframe.sql.parser.SQLParser
import wvlet.airframe.sql.{SQLError, SQLErrorCode}
Expand Down Expand Up @@ -170,6 +170,18 @@ class TypeResolverTest extends AirSpec {
None
)
}

test("resolve intersect") {
val p = analyze("select id from A intersect select id from B") // => Distinct(Intersect(...))
p match {
case Distinct(i @ Intersect(_, _, _), _) =>
i.inputAttributes shouldBe List(ra1, ra2, rb1, rb2)
i.outputAttributes shouldBe List(
ResolvedAttribute("id", DataType.LongType, None, ra1.sourceColumns ++ rb1.sourceColumns, None)
)
case _ => fail(s"unexpected plan:\n${p.pp}")
}
}
}

test("resolve aggregation queries") {
Expand Down

0 comments on commit 33205aa

Please sign in to comment.