Skip to content

Commit

Permalink
Merge pull request #1 from takezoe/sql_validation_macro
Browse files Browse the repository at this point in the history
SQL validation macro
  • Loading branch information
takezoe committed Sep 10, 2016
2 parents e74c365 + 2a024a7 commit 3655192
Show file tree
Hide file tree
Showing 12 changed files with 586 additions and 3 deletions.
45 changes: 44 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -73,4 +73,47 @@ DB.autoClose(conn) { db =>
println(rs.getString("USER_NAME"))
}
}
```
```

## SQL Validation (Experimental)

scala-jdbc provides `sqlc` macro that validates a given SQL. You can use it instead of sql string interpolation.

```scala
db.selectFirst(sqlc("SELECT * FROM USERS WHERE USER_ID = $userId")){ rs =>
(rs.getInt("USER_ID"), rs.getString("USER_NAME"))
}
```

When a given SQL is invalid, this macro reports error in compile time.

In default, this macro checks only sql syntax.
It's also possible to check existence of tables and columns by database schema definition by defining `schema.json` in the current directory.
Here is an example of `schema.json`:

```javascript
{
"tables":[
{
"name": "USER",
"columns": [
{ "name": "USER_ID" },
{ "name": "USER_NAME" },
{ "name": "DEPT_ID" }
]
},
{
"name": "COMPANY",
"columns": [
{ "name": "COMPANY_ID" },
{ "name": "COMPANY_NAME" }
]
}
]
}
```

However `sqlc` macro is still experimental feature.
If you get invalid validation results, please report them to [issues](https://github.com/takezoe/scala-jdbc/issues).


12 changes: 11 additions & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,20 @@ name := "scala-jdbc"

organization := "com.github.takezoe"

version := "1.0.2"
version := "1.0.3-SNAPSHOT"

scalaVersion := "2.11.8"

libraryDependencies ++= Seq(
"com.github.jsqlparser" % "jsqlparser" % "0.9.6",
"org.scalamacros" %% "resetallattrs" % "1.0.0",
"com.fasterxml.jackson.module" %% "jackson-module-scala" % "2.7.2",
"com.github.pathikrit" %% "better-files" % "2.15.0",
"org.scala-lang" % "scala-reflect" % scalaVersion.value,
"org.scala-lang" % "scala-compiler" % scalaVersion.value % "provided",
"org.scalatest" %% "scalatest" % "2.2.1" % "test"
)

publishMavenStyle := true

publishTo <<= version { (v: String) =>
Expand Down
49 changes: 49 additions & 0 deletions schema.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
{
"tables":[
{
"name": "USER",
"columns": [
{
"name": "USER_ID"
},
{
"name": "USER_NAME"
},
{
"name": "DEPT_ID"
}
]
},
{
"name": "COMPANY",
"columns": [
{
"name": "COMPANY_ID"
},
{
"name": "COMPANY_NAME"
}
]
},
{
"name": "DEPT",
"columns": [
{
"name": "DEPT_ID"
},
{
"name": "DEPT_NAME"
}
]
},
{
"name": "DEPT_GROUP",
"columns": [
{
"name": "DEPT_ID"
}
]
}

]
}
53 changes: 52 additions & 1 deletion src/main/scala/com/github/takezoe/scala/jdbc/package.scala
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
package com.github.takezoe.scala

import com.github.takezoe.scala.jdbc.SqlTemplate

import scala.language.experimental.macros
import scala.reflect.macros.blackbox.Context
import com.github.takezoe.scala.jdbc.validation._

package object jdbc {

/**
Expand All @@ -11,9 +17,54 @@ package object jdbc {
* String interpolation to convert a variable-embedded SQL to SqlTemplate
*/
implicit class SqlStringInterpolation(val sc: StringContext) extends AnyVal {
def sql(args: Any*): SqlTemplate = SqlTemplate(sc.parts.mkString("?"), args.toSeq)
def sql(args: Any*): SqlTemplate = {
val sql = sc.parts.mkString
SqlTemplate(sql, args.toSeq)
}
}

case class SqlTemplate(sql: String, params: Any*)

/**
* Macro version of sql string interpolation.
* This macro validates the given sql in compile time and returns SqlTemplate as same as string interpolation.
*/
def sqlc(sql: String): com.github.takezoe.scala.jdbc.SqlTemplate = macro Macros.validateSqlMacro

}

object Macros {

def validateSqlMacro(c: Context)(sql: c.Expr[String]): c.Expr[com.github.takezoe.scala.jdbc.SqlTemplate] = {
import c.universe._
sql.tree match {
case Literal(x) => x.value match {
case sql: String => SqlValidator.validateSql(sql, c)
val Apply(fun, _) = reify(new SqlTemplate("")).tree
c.Expr[com.github.takezoe.scala.jdbc.SqlTemplate](Apply.apply(fun, Literal(x) :: Nil))
}
case Apply(Select(Apply(Select(Select((_, _)), _), trees), _), args) => {
val sql = trees.collect { case Literal(x) => x.value.asInstanceOf[String] }.mkString("?")
SqlValidator.validateSql(sql, c)
val Apply(fun, _) = reify(new SqlTemplate("")).tree
c.Expr[SqlTemplate](Apply.apply(fun, Literal(Constant(sql)) :: args))
}
case Select(Apply(Select(a, b), List(Literal(x))), TermName("stripMargin")) => {
x.value match {
case s: String =>
val sql = s.stripMargin
SqlValidator.validateSql(sql, c)
val Apply(fun, _) = reify(new SqlTemplate("")).tree
c.Expr[SqlTemplate](Apply.apply(fun, Literal(Constant(sql)) :: Nil))
}
}
case Select(Apply(_, List(Apply(Select(Apply(Select(Select((_, _)), _), trees), _), args))), TermName("stripMargin")) => {
val sql = trees.collect { case Literal(x) => x.value.asInstanceOf[String] }.mkString("?").stripMargin
SqlValidator.validateSql(sql, c)
val Apply(fun, _) = reify(new SqlTemplate("")).tree
c.Expr[SqlTemplate](Apply.apply(fun, Literal(Constant(sql)) :: args))
}
}
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package com.github.takezoe.scala.jdbc.validation

import net.sf.jsqlparser.expression.ExpressionVisitorAdapter
import net.sf.jsqlparser.schema.Column
import net.sf.jsqlparser.statement.delete.Delete
import net.sf.jsqlparser.statement.select.SubSelect

import scala.reflect.macros.blackbox

class DeleteValidator(c: blackbox.Context, delete: Delete, schema: Map[String, TableDef]) {

def validate(): Unit = {
val tableName = delete.getTable.getName

schema.get(tableName) match {
case None => if(schema.nonEmpty){
c.error(c.enclosingPosition, "Table " + tableName + " does not exist.")
}
case Some(tableDef) => {
val select = new SelectModel()
val tableModel = new TableModel()
tableModel.select = Left(tableName)
select.from += tableModel

delete.getWhere.accept(new ExpressionVisitorAdapter {
override def visit(column: Column): Unit = {
val c = new ColumnModel()
c.name = column.getColumnName
c.table = Option(tableName)
select.where += c
}

override def visit(subSelect: SubSelect): Unit = {
val visitor = new SelectVisitor(c)
subSelect.getSelectBody.accept(visitor)
select.others += visitor.select
}
})

select.validate(c, schema)
}
}
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package com.github.takezoe.scala.jdbc.validation

import net.sf.jsqlparser.statement.insert.Insert

import scala.collection.JavaConverters._
import scala.reflect.macros.blackbox

class InsertValidator(c: blackbox.Context, insert: Insert, schema: Map[String, TableDef]) {

def validate(): Unit = {
val tableName = insert.getTable.getName

schema.get(tableName) match {
case None => if(schema.nonEmpty){
c.error(c.enclosingPosition, "Table " + tableName + " does not exist.")
}
case Some(tableDef) => insert.getColumns.asScala.foreach { column =>
if(!tableDef.columns.exists(_.name == column.getColumnName)){
c.error(c.enclosingPosition, "Column " + column.getColumnName + " does not exist in " + tableDef.name + ".")
}
}
}

if(insert.getSelect != null){
val visitor = new SelectVisitor(c)
insert.getSelect.getSelectBody.accept(visitor)
visitor.select.validate(c, schema)
}
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package com.github.takezoe.scala.jdbc.validation

import better.files.File
import com.fasterxml.jackson.databind.{DeserializationFeature, ObjectMapper}
import com.fasterxml.jackson.module.scala.DefaultScalaModule

case class SchemaDef(tables: Seq[TableDef])

case class TableDef(name:String, columns: Seq[ColumnDef])

case class ColumnDef(name: String)

object SchemaDef {

private val mapper = new ObjectMapper()
mapper.enable(DeserializationFeature.UNWRAP_SINGLE_VALUE_ARRAYS)
mapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false)
mapper.registerModule(DefaultScalaModule)

def load(): Map[String, TableDef] = {
val file = File("schema.json")
val schema: Map[String, TableDef] = if(file.exists){
val json = file.contentAsString
val schema = mapper.readValue(json, classOf[SchemaDef])
schema.tables.map { t => t.name -> t }.toMap
} else {
Map.empty
}
schema
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package com.github.takezoe.scala.jdbc.validation

import net.sf.jsqlparser.statement.select.Select

import scala.reflect.macros.blackbox

class SelectValidator(c: blackbox.Context, select: Select, schema: Map[String, TableDef]) {

def validate(): Unit = {
val visitor = new SelectVisitor(c)
select.getSelectBody.accept(visitor)

visitor.select.validate(c, schema)
}

}
Loading

0 comments on commit 3655192

Please sign in to comment.