Skip to content

Commit

Permalink
sql parser 支持from List
Browse files Browse the repository at this point in the history
  • Loading branch information
wangzhang7982 committed Jul 16, 2023
1 parent ed91b76 commit 41c70e2
Show file tree
Hide file tree
Showing 7 changed files with 19 additions and 22 deletions.
7 changes: 2 additions & 5 deletions core/src/main/scala/easysql/ast/statement/SqlStatement.scala
Expand Up @@ -43,7 +43,7 @@ sealed trait SqlQuery
case class SqlSelect(
param: Option[String],
select: List[SqlSelectItem],
from: Option[SqlTable],
from: List[SqlTable],
where: Option[SqlExpr],
groupBy: List[SqlExpr],
orderBy: List[SqlOrderBy],
Expand All @@ -64,10 +64,7 @@ case class SqlSelect(
this.copy(
select = this.select ++ that.select,

from = for {
f <- this.from
tf <- that.from
} yield SqlJoinTable(f, SqlJoinType.InnerJoin, tf, None),
from = this.from ++ that.from,

where = (this.where, that.where) match {
case (Some(w), Some(tw)) => Some(SqlBinaryExpr(w, SqlBinaryOperator.And, tw))
Expand Down
6 changes: 3 additions & 3 deletions core/src/main/scala/easysql/parser/SqlParser.scala
Expand Up @@ -220,7 +220,7 @@ class SqlParser extends StandardTokenParsers {
if distinct.isDefined then Some("DISTINCT")
else None
SqlQueryExpr(
SqlSelect(param, s, f, w, g.map(_._1).getOrElse(Nil), o.getOrElse(Nil), false, l, g.map(_._2).getOrElse(None))
SqlSelect(param, s, f.getOrElse(Nil), w, g.map(_._1).getOrElse(Nil), o.getOrElse(Nil), false, l, g.map(_._2).getOrElse(None))
)
}
}
Expand Down Expand Up @@ -266,8 +266,8 @@ class SqlParser extends StandardTokenParsers {
}
}

def from: Parser[SqlTable] =
"FROM" ~> joinTable
def from: Parser[List[SqlTable]] =
"FROM" ~> rep1sep(joinTable, ",")

def where: Parser[SqlExpr] =
"WHERE" ~> expr
Expand Down
2 changes: 1 addition & 1 deletion core/src/main/scala/easysql/printer/ESPrinter.scala
Expand Up @@ -14,7 +14,7 @@ class ESPrinter {

def printSelect(s: SqlSelect) = {
s.from match {
case Some(SqlIdentTable(t, None)) => dslBuilder.append(s"GET /$t/_search")
case SqlIdentTable(t, None) :: Nil => dslBuilder.append(s"GET /$t/_search")
case _ =>
}
dslBuilder.append(" {\n")
Expand Down
2 changes: 1 addition & 1 deletion core/src/main/scala/easysql/printer/MongoPrinter.scala
Expand Up @@ -14,7 +14,7 @@ class MongoPrinter {

def printSelect(s: SqlSelect): Unit = {
s.from match {
case Some(SqlIdentTable(t, None)) => dslBuilder.append(s"db.$t.find(")
case SqlIdentTable(t, None) :: Nil => dslBuilder.append(s"db.$t.find(")
case _ =>
}
s.where match {
Expand Down
4 changes: 2 additions & 2 deletions core/src/main/scala/easysql/printer/SqlPrinter.scala
Expand Up @@ -122,11 +122,11 @@ trait SqlPrinter(val prepare: Boolean) {
printList(select.select)(printSelectItem)
}

select.from.foreach { it =>
if (select.from.nonEmpty) {
sqlBuilder.append("\n")
printSpace(spaceNum)
sqlBuilder.append("FROM ")
printTable(it)
printList(select.from)(printTable)
}

select.where.foreach { it =>
Expand Down
4 changes: 2 additions & 2 deletions core/src/main/scala/easysql/query/select/MonadicQuery.scala
Expand Up @@ -202,7 +202,7 @@ object MonadicQuery {
val sqlSelectItems = table.__cols.map { c =>
SqlSelectItem(exprToSqlExpr(c), None)
}
val query: SqlSelect = SqlSelect(None, sqlSelectItems, Some(fromTable), None, Nil, Nil, false, None, None)
val query: SqlSelect = SqlSelect(None, sqlSelectItems, List(fromTable), None, Nil, Nil, false, None, None)

new MonadicQuery(query, table)
}
Expand All @@ -221,7 +221,7 @@ object MonadicQuery {
def exists: MonadicQuery[Tuple1[Boolean], None.type] = {
val expr = SqlExprFuncExpr("EXISTS", SqlQueryExpr(q.query) :: Nil)
val selectItem = SqlSelectItem(expr, None)
val newQuery: SqlSelect = SqlSelect(None, selectItem :: Nil, None, None, Nil, Nil, false, None, None)
val newQuery: SqlSelect = SqlSelect(None, selectItem :: Nil, Nil, None, Nil, Nil, false, None, None)

new MonadicQuery(newQuery, None)
}
Expand Down
16 changes: 8 additions & 8 deletions core/src/main/scala/easysql/query/select/Select.scala
Expand Up @@ -25,20 +25,20 @@ class Select[T <: Tuple, A <: Tuple](

infix def from(table: TableSchema[_]): Select[T, A] = {
val fromTable = SqlIdentTable(table.__tableName, table.__aliasName)
new Select(ast.copy(from = Some(fromTable)), selectItems, Some(fromTable))
new Select(ast.copy(from = List(fromTable)), selectItems, Some(fromTable))
}

infix def from(table: AliasQuery[_, _])(using inWithQuery: InWithQuery = NotIn): Select[T, A] = {
val fromTable =
if inWithQuery == In
then SqlIdentTable(table.__queryName, None)
else SqlSubQueryTable(table.__ast, false, Some(table.__queryName))
new Select(ast.copy(from = Some(fromTable)), selectItems, Some(fromTable))
new Select(ast.copy(from = List(fromTable)), selectItems, Some(fromTable))
}

infix def fromLateral(table: AliasQuery[_, _]): Select[T, A] = {
val fromTable = SqlSubQueryTable(table.__ast, true, Some(table.__queryName))
new Select(ast.copy(from = Some(fromTable)), selectItems, Some(fromTable))
new Select(ast.copy(from = List(fromTable)), selectItems, Some(fromTable))
}

infix def select[U <: Tuple](items: U): Select[Concat[T, InverseMap[U]], Concat[A, AliasNames[U]]] = {
Expand Down Expand Up @@ -146,7 +146,7 @@ class Select[T <: Tuple, A <: Tuple](
case Some(value) => SqlJoinTable(value, joinType, joinTable, None)
}

new Select(ast.copy(from = Some(fromTable)), selectItems, Some(fromTable))
new Select(ast.copy(from = List(fromTable)), selectItems, Some(fromTable))
}

private def joinClause(table: AliasQuery[_, _], joinType: SqlJoinType, lateral: Boolean)(using inWithQuery: InWithQuery = NotIn): Select[T, A] = {
Expand All @@ -160,7 +160,7 @@ class Select[T <: Tuple, A <: Tuple](
case Some(value) => SqlJoinTable(value, joinType, joinTable, None)
}

new Select(ast.copy(from = Some(fromTable)), selectItems, Some(fromTable))
new Select(ast.copy(from = List(fromTable)), selectItems, Some(fromTable))
}

private def joinClause(table: JoinTable, joinType: SqlJoinType): Select[T, A] = {
Expand All @@ -179,7 +179,7 @@ class Select[T <: Tuple, A <: Tuple](
case Some(value) => SqlJoinTable(value, joinType, joinTable, None)
}

new Select(ast.copy(from = Some(fromTable)), selectItems, Some(fromTable))
new Select(ast.copy(from = List(fromTable)), selectItems, Some(fromTable))
}

infix def on(expr: Expr[Boolean]): Select[T, A] = {
Expand All @@ -188,7 +188,7 @@ class Select[T <: Tuple, A <: Tuple](
case f => f
}

new Select(ast.copy(from = from), selectItems, from)
new Select(ast.copy(from = from), selectItems, from.headOption)
}

infix def join(table: TableSchema[_] | AliasQuery[_, _] | JoinTable)(using inWithQuery: InWithQuery = NotIn): Select[T, A] = table match {
Expand Down Expand Up @@ -251,7 +251,7 @@ class Select[T <: Tuple, A <: Tuple](

object Select {
def apply(): Select[EmptyTuple, EmptyTuple] =
new Select(SqlSelect(None, Nil, None, None, Nil, Nil, false, None, None), Map(), None)
new Select(SqlSelect(None, Nil, Nil, None, Nil, Nil, false, None, None), Map(), None)

given selectToCountSql: ToCountSql[Select[_, _]] with {
extension (x: Select[_, _]) {
Expand Down

0 comments on commit 41c70e2

Please sign in to comment.