Skip to content

Commit

Permalink
airframe-sql: Resolve joins (#2389)
Browse files Browse the repository at this point in the history
* Resolve join using(...)
* Remove ambiguous join keys
* Add using join test
* Add join on support
* Add join key rename test
* Resolve join using different name columns
  • Loading branch information
xerial committed Sep 7, 2022
1 parent 23c1b3a commit f874b89
Show file tree
Hide file tree
Showing 6 changed files with 168 additions and 29 deletions.
Expand Up @@ -32,12 +32,7 @@ object CTEResolver extends LogSupport {

var currentContext = analyzerContext
val resolvedQueries = queryDefs.map { x =>
val resolvedQuery: Relation = TypeResolver.resolve(currentContext, x.query) match {
case r: Relation => r
case other =>
// This should not happen in general
x.query
}
val resolvedQuery: Relation = TypeResolver.resolveRelation(currentContext, x.query)
val cteBody = x.columnNames match {
case None =>
resolvedQuery
Expand All @@ -58,7 +53,7 @@ object CTEResolver extends LogSupport {
// cteBody already has renaming with projection, no need to propagate column name aliases
WithQuery(x.name, cteBody, None)
}
val newBody = TypeResolver.resolve(currentContext, body).asInstanceOf[Relation]
val newBody = TypeResolver.resolveRelation(currentContext, body)
Query(With(recursive, resolvedQueries), newBody)
}
}
Expand Down
Expand Up @@ -15,7 +15,7 @@ package wvlet.airframe.sql.analyzer
import wvlet.airframe.sql.SQLErrorCode
import wvlet.airframe.sql.analyzer.SQLAnalyzer.{PlanRewriter, Rule}
import wvlet.airframe.sql.model.Expression._
import wvlet.airframe.sql.model.LogicalPlan.{Aggregate, Filter, Project, Query, Relation, Union}
import wvlet.airframe.sql.model.LogicalPlan.{Aggregate, Filter, Join, Project, Query, Relation, Union}
import wvlet.airframe.sql.model._
import wvlet.log.LogSupport

Expand All @@ -29,7 +29,8 @@ object TypeResolver extends LogSupport {
TypeResolver.resolveAggregationIndexes _ ::
TypeResolver.resolveCTETableRef _ ::
TypeResolver.resolveTableRef _ ::
TypeResolver.resolveRelation _ ::
TypeResolver.resolveJoinUsing _ ::
TypeResolver.resolveRegularRelation _ ::
TypeResolver.resolveColumns _ ::
TypeResolver.resolveUnion _ ::
Nil
Expand All @@ -39,11 +40,22 @@ object TypeResolver extends LogSupport {
.foldLeft(plan) { (targetPlan, rule) =>
val r = rule.apply(analyzerContext)
// Recursively transform the tree
targetPlan.transform(r)
val resolved = targetPlan.transform(r)
resolved
}
resolvedPlan
}

/**
  * Run all type-resolution rules on the given plan and return the result as a Relation.
  *
  * @param analyzerContext
  *   context passed to every resolution rule
  * @param plan
  *   plan to resolve
  * @return
  *   the resolved plan, when it is a Relation
  *
  * An exception built from SQLErrorCode.InvalidArgument is thrown when the resolved plan is not a Relation.
  */
def resolveRelation(analyzerContext: AnalyzerContext, plan: LogicalPlan): Relation = {
  resolve(analyzerContext, plan) match {
    case relation: Relation =>
      relation
    case _ =>
      // The message reports the original (unresolved) plan
      throw SQLErrorCode.InvalidArgument.newException(s"${plan} isn't a relation")
  }
}

/**
* Translate select i1, i2, ... group by 1, 2, ... query into select i1, i2, ... group by i1, i2
*
Expand Down Expand Up @@ -96,7 +108,35 @@ object TypeResolver extends LogSupport {
}
}

def resolveRelation(context: AnalyzerContext): PlanRewriter = {
/**
  * Look up every join key in the given input attributes.
  *
  * A key may match more than one attribute (e.g., the same column name on both sides of the join); all of the
  * matches are kept, in key order.
  *
  * @param joinKeys
  *   join key expressions to resolve
  * @param inputAttributes
  *   attributes of the join inputs used as lookup candidates
  * @return
  *   all matching expressions for all keys
  *
  * An exception built from SQLErrorCode.ColumnNotFound is thrown when a key has no match.
  */
private def resolveJoinKeys(joinKeys: Seq[Expression], inputAttributes: Seq[Attribute]): Seq[Expression] = {
  joinKeys.flatMap { k =>
    findMatchInInputAttributes(k, inputAttributes) match {
      case Nil =>
        throw SQLErrorCode.ColumnNotFound.newException(s"join key column: ${k.sqlExpr} is not found")
      case matched =>
        matched
    }
  }
}

/**
  * Rewrite rule that resolves the join keys of `JOIN ... USING (...)` and `JOIN ... ON l = r`.
  *
  * Both sides of the join are resolved first, then the join keys are looked up in the input attributes of the
  * resolved join, and the join criteria is replaced with a JoinOnEq holding the resolved keys.
  *
  * Only a JoinOn whose expression is a single Eq(left, right) is matched by this rule; other join criteria are
  * left untouched.
  */
def resolveJoinUsing(context: AnalyzerContext): PlanRewriter = {
  case Join(joinType, left, right, u @ JoinUsing(joinKeys)) =>
    // from A join B using(c1, c2, ...)
    val resolvedJoin = Join(joinType, resolveRelation(context, left), resolveRelation(context, right), u)
    resolvedJoin.withCond(JoinOnEq(resolveJoinKeys(joinKeys, resolvedJoin.inputAttributes)))
  case Join(joinType, left, right, u @ JoinOn(Eq(leftKey, rightKey))) =>
    // from A join B on A.c1 = B.c2
    val resolvedJoin = Join(joinType, resolveRelation(context, left), resolveRelation(context, right), u)
    resolvedJoin.withCond(JoinOnEq(resolveJoinKeys(Seq(leftKey, rightKey), resolvedJoin.inputAttributes)))
}

def resolveRegularRelation(context: AnalyzerContext): PlanRewriter = {
case filter @ Filter(child, filterExpr) =>
filter.transformExpressions { case x: Expression => resolveExpression(x, filter.inputAttributes) }
case r: Relation =>
Expand All @@ -111,10 +151,18 @@ object TypeResolver extends LogSupport {

/**
  * Rewrite rule that resolves the columns of a Project node by looking them up in the output attributes of its
  * child.
  */
def resolveColumns(context: AnalyzerContext): PlanRewriter = { case Project(child, columns) =>
  // Build the resolved Project exactly once; no intermediate Project is created and discarded
  Project(child, resolveOutputColumns(child.outputAttributes, columns))
}

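/**
* Resolve output columns by looking up the inputAttributes
* @param inputAttributes
*   attributes used as lookup candidates (resolveColumns passes the child's output attributes)
* @param outputColumns
*   columns to resolve (resolveColumns passes the columns of the Project)
* @return
*   the resolved output columns
*/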
private def resolveOutputColumns(inputAttributes: Seq[Attribute], outputColumns: Seq[Attribute]): Seq[Attribute] = {

val resolvedColumns = Seq.newBuilder[Attribute]
outputColumns.map {
case a: AllColumns =>
Expand All @@ -132,7 +180,8 @@ object TypeResolver extends LogSupport {
case other =>
resolvedColumns += other
}
resolvedColumns.result()
val output = resolvedColumns.result()
output
}

def resolveAttribute(attribute: Attribute): Attribute = {
Expand All @@ -146,40 +195,49 @@ object TypeResolver extends LogSupport {
}

/**
* Resolve untyped expressions
* Find matching expressions in the inputAttributes
* @param expr
* @param inputAttributes
* @return
*/
def resolveExpression(expr: Expression, inputAttributes: Seq[Attribute]): Expression = {
def findInputAttribute(name: String): Option[Attribute] = {
def findMatchInInputAttributes(expr: Expression, inputAttributes: Seq[Attribute]): List[Expression] = {
def lookup(name: String): List[Attribute] = {
QName(name) match {
case QName(Seq(t1, c1)) =>
val attrs = inputAttributes.collect {
case a @ ResolvedAttribute(c, _, Some(t), _) if t.name == t1 && c == c1 => a
case a @ ResolvedAttribute(c, _, None, _) if c == c1 => a
}
if (attrs.size > 1) {
throw SQLErrorCode.SyntaxError.newException(s"${name} is ambiguous")
}
attrs.headOption
attrs.toList
case QName(Seq(c1)) =>
val attrs = inputAttributes.collect {
case a @ ResolvedAttribute(c, _, _, _) if c == c1 => a
}
if (attrs.size > 1) {
throw SQLErrorCode.SyntaxError.newException(s"${name} is ambiguous")
}
attrs.headOption
attrs.toList
case _ =>
None
List.empty
}
}

expr match {
case i: Identifier =>
findInputAttribute(i.value).getOrElse(i)
lookup(i.value)
case u @ UnresolvedAttribute(name) =>
findInputAttribute(name).getOrElse(u)
lookup(name)
case _ =>
expr
List(expr)
}
}

/**
  * Resolve an untyped expression against the given input attributes.
  *
  * @param expr
  *   expression to resolve
  * @param inputAttributes
  *   candidate attributes to match against
  * @return
  *   the single match found by findMatchInInputAttributes, or expr itself when there is no match
  *
  * An exception built from SQLErrorCode.SyntaxError is thrown when more than one candidate matches.
  */
def resolveExpression(expr: Expression, inputAttributes: Seq[Attribute]): Expression = {
  findMatchInInputAttributes(expr, inputAttributes) match {
    case Nil =>
      // Nothing matched: leave the expression as is
      expr
    case single :: Nil =>
      single
    case _ =>
      throw SQLErrorCode.SyntaxError.newException(s"${expr.sqlExpr} is ambiguous")
  }
}

Expand Down
Expand Up @@ -106,6 +106,24 @@ trait Attribute extends LeafExpression {
object Expression {
import wvlet.airframe.sql.model.LogicalPlan.Relation

/**
  * Combine the given expressions into a single expression, from left to right, using the merger function.
  *
  * A single-element input is returned as is, without calling merger.
  *
  * @param expr
  *   expressions to combine; must not be empty (otherwise require fails with IllegalArgumentException)
  * @param merger
  *   function that combines the expression accumulated so far with the next one
  * @return
  *   the combined expression
  */
def concat(expr: Seq[Expression])(merger: (Expression, Expression) => Expression): Expression = {
  require(expr.nonEmpty)
  expr.reduceLeft(merger)
}

/**
  * Chain the given expressions with And, left to right: Seq(a, b, c) becomes And(And(a, b), c).
  */
def concatWithAnd(expr: Seq[Expression]): Expression = {
  concat(expr)((lhs, rhs) => And(lhs, rhs))
}
/**
  * Chain the given expressions with Eq, left to right: Seq(a, b) becomes Eq(a, b).
  *
  * NOTE(review): with three or more expressions the result nests as Eq(Eq(a, b), c) rather than a conjunction
  * of pairwise equalities -- confirm that callers only rely on this for a single key pair.
  */
def concatWithEq(expr: Seq[Expression]): Expression = {
  concat(expr)((lhs, rhs) => Eq(lhs, rhs))
}

/**
*/
case class ParenthesizedExpression(child: Expression) extends UnaryExpression
Expand Down Expand Up @@ -157,6 +175,33 @@ object Expression {
override def child: Expression = expr
}

/**
  * Join condition used only when join keys are resolved
  *
  * @param keys
  *   resolved join key expressions; construction fails (require) if any key is not resolved
  */
case class JoinOnEq(keys: Seq[Expression]) extends JoinCriteria with LeafExpression {
  require(keys.forall(_.resolved), s"all keys of JoinOnEq must be resolved: ${keys}")

  /**
    * Report duplicate name join keys, which can be excluded from the parent.
    *
    * The first ResolvedAttribute seen for each name is kept; every key that is not equal to one of those kept
    * attributes is reported. Keys that are not ResolvedAttribute are therefore always reported.
    *
    * @return
    *   the keys regarded as duplicates, in their original order
    */
  def duplicateKeys: Seq[Expression] = {
    // Collect the first ResolvedAttribute for each distinct column name
    val (_, firstOfEachName) =
      keys.foldLeft((Set.empty[String], Vector.empty[Expression])) {
        case ((names, kept), r: ResolvedAttribute) if !names.contains(r.name) =>
          (names + r.name, kept :+ r)
        case (acc, _) =>
          acc
      }
    keys.filterNot(k => firstOfEachName.contains(k))
  }

  override def children: Seq[Expression] = keys
}

case class AllColumns(prefix: Option[QName]) extends Attribute {
override def name: String = prefix.map(x => s"${x}.*").getOrElse("*")
override def children: Seq[Expression] = prefix.toSeq
Expand Down
Expand Up @@ -354,7 +354,17 @@ object LogicalPlan {
}
override def inputAttributes: Seq[Attribute] =
left.outputAttributes ++ right.outputAttributes
override def outputAttributes: Seq[Attribute] = inputAttributes
/**
  * Output attributes of this join.
  *
  * When the join criteria is a JoinOnEq, the keys it reports as duplicates are dropped from the input
  * attributes; for any other criteria the input attributes are returned unchanged.
  */
override def outputAttributes: Seq[Attribute] = cond match {
  case joinOnEq: JoinOnEq =>
    // Remove join key duplication here
    val duplicated = joinOnEq.duplicateKeys
    inputAttributes.filterNot(attr => duplicated.contains(attr))
  case _ =>
    inputAttributes
}

/**
  * Create a copy of this join whose join criteria is replaced with the given one.
  */
def withCond(cond: JoinCriteria): Join = {
  copy(cond = cond)
}
}
sealed abstract class JoinType(val symbol: String)
// Exact match (= equi join)
Expand Down
Expand Up @@ -211,6 +211,7 @@ object SQLGenerator extends LogSupport {
case NaturalJoin => ""
case JoinUsing(columns) => s" USING (${columns.map(_.sqlExpr).mkString(", ")})"
case JoinOn(expr) => s" ON ${printExpression(expr)}"
case JoinOnEq(keys) => s" ON ${printExpression(Expression.concatWithEq(keys))}"
}
joinType match {
case InnerJoin => s"${l} JOIN ${r}${c}"
Expand Down
Expand Up @@ -210,4 +210,34 @@ class TypeResolverTest extends AirSpec {
)
}
}

// Checks the output attributes of analyzed queries that contain joins.
// The fixtures (analyze, tableA, tableB, a1, a2, b1) are defined outside this block.
test("resolve join attributes") {
// USING(id): the unqualified `id` in the select list is expected to resolve to table A's id column
test("join with USING") {
val p = analyze("select id, A.name from A join B using(id)")
p.outputAttributes shouldBe List(
ResolvedAttribute("id", DataType.LongType, Some(tableA), Some(a1)),
ResolvedAttribute("name", DataType.StringType, Some(tableA), Some(a2))
)
}

// ON A.id = B.id: the expected output is the same as the USING(id) case above
test("join with on") {
val p = analyze("select id, A.name from A join B on A.id = B.id")
p.outputAttributes shouldBe List(
ResolvedAttribute("id", DataType.LongType, Some(tableA), Some(a1)),
ResolvedAttribute("name", DataType.StringType, Some(tableA), Some(a2))
)
}

// Join keys with different names (A.id = B.pid): `pid` is expected to resolve to table B's column (b1).
// NOTE(review): the sub-query has no alias, yet the ON clause refers to it as `B` -- confirm this is intended.
test("join with different column names") {
val p = analyze("select pid, name from A join (select id pid from B) on A.id = B.pid")
p.outputAttributes shouldBe List(
ResolvedAttribute("pid", DataType.LongType, Some(tableB), Some(b1)),
ResolvedAttribute("name", DataType.StringType, Some(tableA), Some(a2))
)
}

// Not implemented yet; reported as pending
test("3-way joins") {
pending("TODO")
}
}
}

0 comments on commit f874b89

Please sign in to comment.