Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SQL validation macro #1

Merged
merged 23 commits into from
Sep 10, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 44 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -73,4 +73,47 @@ DB.autoClose(conn) { db =>
println(rs.getString("USER_NAME"))
}
}
```
```

## SQL Validation (Experimental)

scala-jdbc provides `sqlc` macro that validates a given SQL. You can use it instead of sql string interpolation.

```scala
db.selectFirst(sqlc("SELECT * FROM USERS WHERE USER_ID = $userId")){ rs =>
(rs.getInt("USER_ID"), rs.getString("USER_NAME"))
}
```

When a given SQL is invalid, this macro reports error in compile time.

In default, this macro checks only sql syntax.
It's also possible to check existence of tables and columns by database schema definition by defining `schema.json` in the current directory.
Here is an example of `schema.json`:

```javascript
{
"tables":[
{
"name": "USER",
"columns": [
{ "name": "USER_ID" },
{ "name": "USER_NAME" },
{ "name": "DEPT_ID" }
]
},
{
"name": "COMPANY",
"columns": [
{ "name": "COMPANY_ID" },
{ "name": "COMPANY_NAME" }
]
}
]
}
```

However `sqlc` macro is still experimental feature.
If you get invalid validation results, please report them to [issues](https://github.com/takezoe/scala-jdbc/issues).


12 changes: 11 additions & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,20 @@ name := "scala-jdbc"

organization := "com.github.takezoe"

version := "1.0.2"
version := "1.0.3-SNAPSHOT"

scalaVersion := "2.11.8"

libraryDependencies ++= Seq(
"com.github.jsqlparser" % "jsqlparser" % "0.9.6",
"org.scalamacros" %% "resetallattrs" % "1.0.0",
"com.fasterxml.jackson.module" %% "jackson-module-scala" % "2.7.2",
"com.github.pathikrit" %% "better-files" % "2.15.0",
"org.scala-lang" % "scala-reflect" % scalaVersion.value,
"org.scala-lang" % "scala-compiler" % scalaVersion.value % "provided",
"org.scalatest" %% "scalatest" % "2.2.1" % "test"
)

publishMavenStyle := true

publishTo <<= version { (v: String) =>
Expand Down
49 changes: 49 additions & 0 deletions schema.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
{
"tables":[
{
"name": "USER",
"columns": [
{
"name": "USER_ID"
},
{
"name": "USER_NAME"
},
{
"name": "DEPT_ID"
}
]
},
{
"name": "COMPANY",
"columns": [
{
"name": "COMPANY_ID"
},
{
"name": "COMPANY_NAME"
}
]
},
{
"name": "DEPT",
"columns": [
{
"name": "DEPT_ID"
},
{
"name": "DEPT_NAME"
}
]
},
{
"name": "DEPT_GROUP",
"columns": [
{
"name": "DEPT_ID"
}
]
}

]
}
53 changes: 52 additions & 1 deletion src/main/scala/com/github/takezoe/scala/jdbc/package.scala
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
package com.github.takezoe.scala

import com.github.takezoe.scala.jdbc.SqlTemplate

import scala.language.experimental.macros
import scala.reflect.macros.blackbox.Context
import com.github.takezoe.scala.jdbc.validation._

package object jdbc {

/**
Expand All @@ -11,9 +17,54 @@ package object jdbc {
* String interpolation to convert a variable-embedded SQL to SqlTemplate
*/
implicit class SqlStringInterpolation(val sc: StringContext) extends AnyVal {
def sql(args: Any*): SqlTemplate = SqlTemplate(sc.parts.mkString("?"), args.toSeq)
def sql(args: Any*): SqlTemplate = {
val sql = sc.parts.mkString
SqlTemplate(sql, args.toSeq)
}
}

case class SqlTemplate(sql: String, params: Any*)

/**
* Macro version of sql string interpolation.
* This macro validates the given sql in compile time and returns SqlTemplate as same as string interpolation.
*/
def sqlc(sql: String): com.github.takezoe.scala.jdbc.SqlTemplate = macro Macros.validateSqlMacro

}

object Macros {

def validateSqlMacro(c: Context)(sql: c.Expr[String]): c.Expr[com.github.takezoe.scala.jdbc.SqlTemplate] = {
import c.universe._
sql.tree match {
case Literal(x) => x.value match {
case sql: String => SqlValidator.validateSql(sql, c)
val Apply(fun, _) = reify(new SqlTemplate("")).tree
c.Expr[com.github.takezoe.scala.jdbc.SqlTemplate](Apply.apply(fun, Literal(x) :: Nil))
}
case Apply(Select(Apply(Select(Select((_, _)), _), trees), _), args) => {
val sql = trees.collect { case Literal(x) => x.value.asInstanceOf[String] }.mkString("?")
SqlValidator.validateSql(sql, c)
val Apply(fun, _) = reify(new SqlTemplate("")).tree
c.Expr[SqlTemplate](Apply.apply(fun, Literal(Constant(sql)) :: args))
}
case Select(Apply(Select(a, b), List(Literal(x))), TermName("stripMargin")) => {
x.value match {
case s: String =>
val sql = s.stripMargin
SqlValidator.validateSql(sql, c)
val Apply(fun, _) = reify(new SqlTemplate("")).tree
c.Expr[SqlTemplate](Apply.apply(fun, Literal(Constant(sql)) :: Nil))
}
}
case Select(Apply(_, List(Apply(Select(Apply(Select(Select((_, _)), _), trees), _), args))), TermName("stripMargin")) => {
val sql = trees.collect { case Literal(x) => x.value.asInstanceOf[String] }.mkString("?").stripMargin
SqlValidator.validateSql(sql, c)
val Apply(fun, _) = reify(new SqlTemplate("")).tree
c.Expr[SqlTemplate](Apply.apply(fun, Literal(Constant(sql)) :: args))
}
}
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package com.github.takezoe.scala.jdbc.validation

import net.sf.jsqlparser.expression.ExpressionVisitorAdapter
import net.sf.jsqlparser.schema.Column
import net.sf.jsqlparser.statement.delete.Delete
import net.sf.jsqlparser.statement.select.SubSelect

import scala.reflect.macros.blackbox

class DeleteValidator(c: blackbox.Context, delete: Delete, schema: Map[String, TableDef]) {

def validate(): Unit = {
val tableName = delete.getTable.getName

schema.get(tableName) match {
case None => if(schema.nonEmpty){
c.error(c.enclosingPosition, "Table " + tableName + " does not exist.")
}
case Some(tableDef) => {
val select = new SelectModel()
val tableModel = new TableModel()
tableModel.select = Left(tableName)
select.from += tableModel

delete.getWhere.accept(new ExpressionVisitorAdapter {
override def visit(column: Column): Unit = {
val c = new ColumnModel()
c.name = column.getColumnName
c.table = Option(tableName)
select.where += c
}

override def visit(subSelect: SubSelect): Unit = {
val visitor = new SelectVisitor(c)
subSelect.getSelectBody.accept(visitor)
select.others += visitor.select
}
})

select.validate(c, schema)
}
}
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package com.github.takezoe.scala.jdbc.validation

import net.sf.jsqlparser.statement.insert.Insert

import scala.collection.JavaConverters._
import scala.reflect.macros.blackbox

class InsertValidator(c: blackbox.Context, insert: Insert, schema: Map[String, TableDef]) {

def validate(): Unit = {
val tableName = insert.getTable.getName

schema.get(tableName) match {
case None => if(schema.nonEmpty){
c.error(c.enclosingPosition, "Table " + tableName + " does not exist.")
}
case Some(tableDef) => insert.getColumns.asScala.foreach { column =>
if(!tableDef.columns.exists(_.name == column.getColumnName)){
c.error(c.enclosingPosition, "Column " + column.getColumnName + " does not exist in " + tableDef.name + ".")
}
}
}

if(insert.getSelect != null){
val visitor = new SelectVisitor(c)
insert.getSelect.getSelectBody.accept(visitor)
visitor.select.validate(c, schema)
}
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package com.github.takezoe.scala.jdbc.validation

import better.files.File
import com.fasterxml.jackson.databind.{DeserializationFeature, ObjectMapper}
import com.fasterxml.jackson.module.scala.DefaultScalaModule

case class SchemaDef(tables: Seq[TableDef])

case class TableDef(name:String, columns: Seq[ColumnDef])

case class ColumnDef(name: String)

object SchemaDef {

private val mapper = new ObjectMapper()
mapper.enable(DeserializationFeature.UNWRAP_SINGLE_VALUE_ARRAYS)
mapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false)
mapper.registerModule(DefaultScalaModule)

def load(): Map[String, TableDef] = {
val file = File("schema.json")
val schema: Map[String, TableDef] = if(file.exists){
val json = file.contentAsString
val schema = mapper.readValue(json, classOf[SchemaDef])
schema.tables.map { t => t.name -> t }.toMap
} else {
Map.empty
}
schema
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package com.github.takezoe.scala.jdbc.validation

import net.sf.jsqlparser.statement.select.Select

import scala.reflect.macros.blackbox

class SelectValidator(c: blackbox.Context, select: Select, schema: Map[String, TableDef]) {

def validate(): Unit = {
val visitor = new SelectVisitor(c)
select.getSelectBody.accept(visitor)

visitor.select.validate(c, schema)
}

}
Loading