Skip to content

Commit

Permalink
Merge pull request #4 from takezoe/sql_validation_against_db
Browse files Browse the repository at this point in the history
SQL validation against real database
  • Loading branch information
takezoe committed Sep 26, 2016
2 parents 543e90a + f25931d commit d4daad1
Show file tree
Hide file tree
Showing 6 changed files with 178 additions and 73 deletions.
5 changes: 3 additions & 2 deletions build.sbt
Expand Up @@ -9,8 +9,7 @@ scalaVersion := "2.11.8"
libraryDependencies ++= Seq(
"com.github.jsqlparser" % "jsqlparser" % "0.9.6",
"org.scalamacros" %% "resetallattrs" % "1.0.0",
"com.fasterxml.jackson.module" %% "jackson-module-scala" % "2.7.2",
"com.github.pathikrit" %% "better-files" % "2.15.0",
"com.fasterxml.jackson.module" %% "jackson-module-scala" % "2.7.2",
"org.scala-lang" % "scala-reflect" % scalaVersion.value,
"org.scala-lang" % "scala-compiler" % scalaVersion.value % "provided",
"org.scalatest" %% "scalatest" % "2.2.1" % "test"
Expand All @@ -26,6 +25,8 @@ publishTo <<= version { (v: String) =>

scalacOptions := Seq("-deprecation")

//unmanagedClasspath in Compile += baseDirectory.value / "src" / "main" / "resources"

publishArtifact in Test := false

pomIncludeRepository := { _ => false }
Expand Down
55 changes: 16 additions & 39 deletions src/main/scala/com/github/takezoe/scala/jdbc/DB.scala
Expand Up @@ -2,6 +2,7 @@ package com.github.takezoe.scala.jdbc

import java.sql._
import scala.reflect.ClassTag
import IOUtils._

object DB {

Expand All @@ -23,19 +24,12 @@ class DB(conn: Connection, typeMapper: TypeMapper){

def selectFirst[T](template: SqlTemplate)(f: ResultSet => T): Option[T] = {
execute(conn, template){ stmt =>
try {
val rs = stmt.executeQuery()
try {
if(rs.next){
Some(f(rs))
} else {
None
}
} finally {
rs.close()
using(stmt.executeQuery()){ rs =>
if(rs.next){
Some(f(rs))
} else {
None
}
} finally {
stmt.close()
}
}
}
Expand Down Expand Up @@ -108,19 +102,12 @@ class DB(conn: Connection, typeMapper: TypeMapper){

def select[T](template: SqlTemplate)(f: ResultSet => T): Seq[T] = {
execute(conn, template){ stmt =>
try {
val rs = stmt.executeQuery()
try {
val list = new scala.collection.mutable.ListBuffer[T]
while(rs.next){
list += f(rs)
}
list.toSeq
} finally {
rs.close()
using(stmt.executeQuery()){ rs =>
val list = new scala.collection.mutable.ListBuffer[T]
while(rs.next){
list += f(rs)
}
} finally {
stmt.close()
list.toSeq
}
}
}
Expand Down Expand Up @@ -201,17 +188,10 @@ class DB(conn: Connection, typeMapper: TypeMapper){

def scan[T](template: SqlTemplate)(f: ResultSet => Unit): Unit = {
execute(conn, template){ stmt =>
try {
val rs = stmt.executeQuery()
try {
while(rs.next){
f(rs)
}
} finally {
rs.close()
using(stmt.executeQuery()){ rs =>
while(rs.next){
f(rs)
}
} finally {
stmt.close()
}
}
}
Expand Down Expand Up @@ -298,22 +278,19 @@ class DB(conn: Connection, typeMapper: TypeMapper){
r
} catch {
case e: Throwable =>
conn.rollback()
rollbackQuietly(conn)
throw e
}
}

def close(): Unit = conn.close()

protected def execute[T](conn: Connection, template: SqlTemplate)(f: (PreparedStatement) => T): T = {
val stmt = conn.prepareStatement(template.sql)
try {
using(conn.prepareStatement(template.sql)){ stmt =>
template.params.zipWithIndex.foreach { case (x, i) =>
typeMapper.set(stmt, i + 1, x)
}
f(stmt)
} finally {
stmt.close()
}
}

Expand Down
46 changes: 46 additions & 0 deletions src/main/scala/com/github/takezoe/scala/jdbc/IOUtils.scala
@@ -0,0 +1,46 @@
package com.github.takezoe.scala.jdbc

import java.io.{ByteArrayOutputStream, InputStream}
import java.sql.Connection

object IOUtils {

def closeQuietly(closeable: AutoCloseable): Unit = {
if(closeable != null){
try {
closeable.close()
} catch {
case e: Exception => // Ignore
}
}
}

def rollbackQuietly(conn: Connection): Unit = {
try {
conn.rollback()
} catch {
case e: Exception => e.printStackTrace()
}
}

def using[T <: AutoCloseable, R](closeable: T)(f: T => R): R = {
try {
f(closeable)
} finally {
closeQuietly(closeable)
}
}


def readStreamAsString(in: InputStream): String = {
val buf = new Array[Byte](1024 * 8)
var length = 0
using(new ByteArrayOutputStream()) { out =>
while ({ length = in.read(buf); length } != -1) {
out.write(buf, 0, length)
}
new String(out.toByteArray, "UTF-8")
}
}

}
13 changes: 9 additions & 4 deletions src/main/scala/com/github/takezoe/scala/jdbc/package.scala
Expand Up @@ -39,28 +39,33 @@ object Macros {
import c.universe._
sql.tree match {
case Literal(x) => x.value match {
case sql: String => SqlValidator.validateSql(sql, c)
case sql: String => SqlValidator.validateSql(sql, Nil, c)
val Apply(fun, _) = reify(new SqlTemplate("")).tree
c.Expr[com.github.takezoe.scala.jdbc.SqlTemplate](Apply.apply(fun, Literal(x) :: Nil))
}
case Apply(Select(Apply(Select(Select((_, _)), _), trees), _), args) => {
val sql = trees.collect { case Literal(x) => x.value.asInstanceOf[String] }.mkString("?")
SqlValidator.validateSql(sql, c)
SqlValidator.validateSql(sql, args.map(_.tpe.toString), c)
val Apply(fun, _) = reify(new SqlTemplate("")).tree

args.foreach { arg =>
println(arg.tpe.getClass)
}

c.Expr[SqlTemplate](Apply.apply(fun, Literal(Constant(sql)) :: args))
}
case Select(Apply(Select(a, b), List(Literal(x))), TermName("stripMargin")) => {
x.value match {
case s: String =>
val sql = s.stripMargin
SqlValidator.validateSql(sql, c)
SqlValidator.validateSql(sql, Nil, c)
val Apply(fun, _) = reify(new SqlTemplate("")).tree
c.Expr[SqlTemplate](Apply.apply(fun, Literal(Constant(sql)) :: Nil))
}
}
case Select(Apply(_, List(Apply(Select(Apply(Select(Select((_, _)), _), trees), _), args))), TermName("stripMargin")) => {
val sql = trees.collect { case Literal(x) => x.value.asInstanceOf[String] }.mkString("?").stripMargin
SqlValidator.validateSql(sql, c)
SqlValidator.validateSql(sql, args.map(_.tpe.toString), c)
val Apply(fun, _) = reify(new SqlTemplate("")).tree
c.Expr[SqlTemplate](Apply.apply(fun, Literal(Constant(sql)) :: args))
}
Expand Down
@@ -1,10 +1,18 @@
package com.github.takezoe.scala.jdbc.validation

import better.files.File
import java.io.{File, FileInputStream}

import com.github.takezoe.scala.jdbc.IOUtils._
import com.fasterxml.jackson.databind.{DeserializationFeature, ObjectMapper}
import com.fasterxml.jackson.module.scala.DefaultScalaModule

case class SchemaDef(tables: Seq[TableDef])
case class SchemaDef(tables: Seq[TableDef], connection: Option[ConnectionDef]){
def toMap: Map[String, TableDef] = {
tables.map { t => t.name -> t }.toMap
}
}

case class ConnectionDef(driver: String, url: String, user: String, password: String)

case class TableDef(name:String, columns: Seq[ColumnDef])

Expand All @@ -17,16 +25,23 @@ object SchemaDef {
mapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false)
mapper.registerModule(DefaultScalaModule)

def load(): Map[String, TableDef] = {
val file = File("schema.json")
val schema: Map[String, TableDef] = if(file.exists){
val json = file.contentAsString
val schema = mapper.readValue(json, classOf[SchemaDef])
schema.tables.map { t => t.name -> t }.toMap
def load(): Option[SchemaDef] = {
val file = new File("schema.json")
if(file.exists){
// Load from file system
val json = using(new FileInputStream(file)){ in =>
readStreamAsString(in)
}
Some(mapper.readValue(json, classOf[SchemaDef]))
} else {
Map.empty
val in = Thread.currentThread.getContextClassLoader.getResourceAsStream("schema.json")
Option(in).map { in =>
// Load from classpath
val json = using(in){ in =>
readStreamAsString(in)
}
mapper.readValue(json, classOf[SchemaDef])
}
}
schema
}

}
@@ -1,5 +1,7 @@
package com.github.takezoe.scala.jdbc.validation

import java.sql.{Date, DriverManager, Time, Timestamp}

import net.sf.jsqlparser.JSQLParserException
import net.sf.jsqlparser.parser.CCJSqlParserUtil
import net.sf.jsqlparser.statement.StatementVisitorAdapter
Expand All @@ -9,28 +11,87 @@ import net.sf.jsqlparser.statement.update.Update

import scala.reflect.macros.blackbox.Context

import com.github.takezoe.scala.jdbc.IOUtils._
import com.github.takezoe.scala.jdbc.TypeMapper

object SqlValidator {

def validateSql(sql: String, c: Context): Unit = {
val schema = SchemaDef.load()
try {
val parse = CCJSqlParserUtil.parse(sql)
parse.accept(new StatementVisitorAdapter {
override def visit(select: net.sf.jsqlparser.statement.select.Select): Unit = {
new SelectValidator(c, select, schema).validate()
}
override def visit(insert: Insert): Unit = {
new InsertValidator(c, insert, schema).validate()
val typeMapper = new TypeMapper() // TODO It should be replaceable.

def validateSql(sql: String, types: Seq[String], c: Context): Unit = {
SchemaDef.load() match {
case None => {
try {
CCJSqlParserUtil.parse(sql)
} catch {
case e: JSQLParserException => c.error(c.enclosingPosition, e.getCause.getMessage)
}
override def visit(update: Update): Unit = {
new UpdateValidator(c, update, schema).validate()
}
case Some(SchemaDef(_, Some(connection))) => {
Class.forName(connection.driver)
val conn = DriverManager.getConnection(connection.url, connection.user, connection.password)
try {
conn.setAutoCommit(false)
using(conn.prepareStatement(adjustSql(sql))){ stmt =>
try {
types.zipWithIndex.foreach { case (t, i) =>
typeMapper.set(stmt, i + 1, getTestValue(t))
}
stmt.execute()
} catch {
case e: Exception => c.error(c.enclosingPosition, e.toString)
}
}
} finally {
rollbackQuietly(conn)
closeQuietly(conn)
}
override def visit(delete: Delete): Unit = {
new DeleteValidator(c, delete, schema).validate()
}
case Some(schemaDef) => {
try {
val parse = CCJSqlParserUtil.parse(sql)
val schema = schemaDef.toMap
parse.accept(new StatementVisitorAdapter {
override def visit(select: net.sf.jsqlparser.statement.select.Select): Unit = {
new SelectValidator(c, select, schema).validate()
}
override def visit(insert: Insert): Unit = {
new InsertValidator(c, insert, schema).validate()
}
override def visit(update: Update): Unit = {
new UpdateValidator(c, update, schema).validate()
}
override def visit(delete: Delete): Unit = {
new DeleteValidator(c, delete, schema).validate()
}
})
} catch {
case e: JSQLParserException => c.error(c.enclosingPosition, e.getCause.getMessage)
}
})
} catch {
case e: JSQLParserException => c.error(c.enclosingPosition, e.getCause.getMessage)
}
}
}

private def adjustSql(sql: String): String = {
if(sql.trim.toUpperCase.startsWith("SELECT")){
sql + " LIMIT 0"
} else {
sql
}
}

// TODO Move to TypeMapper?
private def getTestValue(t: String): Any = {
t match {
case "Int" => 0
case "Long" => 0L
case "Double" => 0D
case "Short" => 0:Short
case "Float" => 0F
case "java.sql.Timestamp" => new Timestamp(System.currentTimeMillis)
case "java.sql.Date" => new Date(System.currentTimeMillis)
case "java.sql.Time" => new Time(System.currentTimeMillis)
case "String" => "-"
}
}

Expand Down

0 comments on commit d4daad1

Please sign in to comment.