diff --git a/build.sbt b/build.sbt index 68c8960..8cf09f4 100644 --- a/build.sbt +++ b/build.sbt @@ -9,8 +9,7 @@ scalaVersion := "2.11.8" libraryDependencies ++= Seq( "com.github.jsqlparser" % "jsqlparser" % "0.9.6", "org.scalamacros" %% "resetallattrs" % "1.0.0", - "com.fasterxml.jackson.module" %% "jackson-module-scala" % "2.7.2", - "com.github.pathikrit" %% "better-files" % "2.15.0", + "com.fasterxml.jackson.module" %% "jackson-module-scala" % "2.7.2", "org.scala-lang" % "scala-reflect" % scalaVersion.value, "org.scala-lang" % "scala-compiler" % scalaVersion.value % "provided", "org.scalatest" %% "scalatest" % "2.2.1" % "test" @@ -26,6 +25,8 @@ publishTo <<= version { (v: String) => scalacOptions := Seq("-deprecation") +//unmanagedClasspath in Compile += baseDirectory.value / "src" / "main" / "resources" + publishArtifact in Test := false pomIncludeRepository := { _ => false } diff --git a/src/main/scala/com/github/takezoe/scala/jdbc/DB.scala b/src/main/scala/com/github/takezoe/scala/jdbc/DB.scala index f6910c2..e2901cf 100644 --- a/src/main/scala/com/github/takezoe/scala/jdbc/DB.scala +++ b/src/main/scala/com/github/takezoe/scala/jdbc/DB.scala @@ -2,6 +2,7 @@ package com.github.takezoe.scala.jdbc import java.sql._ import scala.reflect.ClassTag +import IOUtils._ object DB { @@ -23,19 +24,12 @@ class DB(conn: Connection, typeMapper: TypeMapper){ def selectFirst[T](template: SqlTemplate)(f: ResultSet => T): Option[T] = { execute(conn, template){ stmt => - try { - val rs = stmt.executeQuery() - try { - if(rs.next){ - Some(f(rs)) - } else { - None - } - } finally { - rs.close() + using(stmt.executeQuery()){ rs => + if(rs.next){ + Some(f(rs)) + } else { + None } - } finally { - stmt.close() } } } @@ -108,19 +102,12 @@ class DB(conn: Connection, typeMapper: TypeMapper){ def select[T](template: SqlTemplate)(f: ResultSet => T): Seq[T] = { execute(conn, template){ stmt => - try { - val rs = stmt.executeQuery() - try { - val list = new scala.collection.mutable.ListBuffer[T] - while(rs.next){ - list += f(rs) - } - list.toSeq - } finally { - rs.close() + using(stmt.executeQuery()){ rs => + val list = new scala.collection.mutable.ListBuffer[T] + while(rs.next){ + list += f(rs) } - } finally { - stmt.close() + list.toSeq } } } @@ -201,17 +188,10 @@ class DB(conn: Connection, typeMapper: TypeMapper){ def scan[T](template: SqlTemplate)(f: ResultSet => Unit): Unit = { execute(conn, template){ stmt => - try { - val rs = stmt.executeQuery() - try { - while(rs.next){ - f(rs) - } - } finally { - rs.close() + using(stmt.executeQuery()){ rs => + while(rs.next){ + f(rs) } - } finally { - stmt.close() } } } @@ -298,7 +278,7 @@ class DB(conn: Connection, typeMapper: TypeMapper){ r } catch { case e: Throwable => - conn.rollback() + rollbackQuietly(conn) throw e } } @@ -306,14 +286,11 @@ class DB(conn: Connection, typeMapper: TypeMapper){ def close(): Unit = conn.close() protected def execute[T](conn: Connection, template: SqlTemplate)(f: (PreparedStatement) => T): T = { - val stmt = conn.prepareStatement(template.sql) - try { + using(conn.prepareStatement(template.sql)){ stmt => template.params.zipWithIndex.foreach { case (x, i) => typeMapper.set(stmt, i + 1, x) } f(stmt) - } finally { - stmt.close() } } diff --git a/src/main/scala/com/github/takezoe/scala/jdbc/IOUtils.scala b/src/main/scala/com/github/takezoe/scala/jdbc/IOUtils.scala new file mode 100644 index 0000000..ebafb2e --- /dev/null +++ b/src/main/scala/com/github/takezoe/scala/jdbc/IOUtils.scala @@ -0,0 +1,46 @@ +package com.github.takezoe.scala.jdbc + +import java.io.{ByteArrayOutputStream, InputStream} +import java.sql.Connection + +object IOUtils { + + def closeQuietly(closeable: AutoCloseable): Unit = { + if(closeable != null){ + try { + closeable.close() + } catch { + case e: Exception => // Ignore + } + } + } + + def rollbackQuietly(conn: Connection): Unit = { + try { + conn.rollback() + } catch { + case e: Exception => e.printStackTrace() + } + } + + def using[T <: AutoCloseable, R](closeable: T)(f: T => R): R = { + try { + f(closeable) + } finally { + closeQuietly(closeable) + } + } + + + def readStreamAsString(in: InputStream): String = { + val buf = new Array[Byte](1024 * 8) + var length = 0 + using(new ByteArrayOutputStream()) { out => + while ({ length = in.read(buf); length } != -1) { + out.write(buf, 0, length) + } + new String(out.toByteArray, "UTF-8") + } + } + +} diff --git a/src/main/scala/com/github/takezoe/scala/jdbc/package.scala b/src/main/scala/com/github/takezoe/scala/jdbc/package.scala index f89c21a..4a91194 100644 --- a/src/main/scala/com/github/takezoe/scala/jdbc/package.scala +++ b/src/main/scala/com/github/takezoe/scala/jdbc/package.scala @@ -39,28 +39,33 @@ object Macros { import c.universe._ sql.tree match { case Literal(x) => x.value match { - case sql: String => SqlValidator.validateSql(sql, c) + case sql: String => SqlValidator.validateSql(sql, Nil, c) val Apply(fun, _) = reify(new SqlTemplate("")).tree c.Expr[com.github.takezoe.scala.jdbc.SqlTemplate](Apply.apply(fun, Literal(x) :: Nil)) } case Apply(Select(Apply(Select(Select((_, _)), _), trees), _), args) => { val sql = trees.collect { case Literal(x) => x.value.asInstanceOf[String] }.mkString("?") - SqlValidator.validateSql(sql, c) + SqlValidator.validateSql(sql, args.map(_.tpe.toString), c) val Apply(fun, _) = reify(new SqlTemplate("")).tree + + args.foreach { arg => + println(arg.tpe.getClass) + } + c.Expr[SqlTemplate](Apply.apply(fun, Literal(Constant(sql)) :: args)) } case Select(Apply(Select(a, b), List(Literal(x))), TermName("stripMargin")) => { x.value match { case s: String => val sql = s.stripMargin - SqlValidator.validateSql(sql, c) + SqlValidator.validateSql(sql, Nil, c) val Apply(fun, _) = reify(new SqlTemplate("")).tree c.Expr[SqlTemplate](Apply.apply(fun, Literal(Constant(sql)) :: Nil)) } } case Select(Apply(_, List(Apply(Select(Apply(Select(Select((_, _)), _), trees), _), args))), TermName("stripMargin")) => { val sql = trees.collect { case Literal(x) => x.value.asInstanceOf[String] }.mkString("?").stripMargin - SqlValidator.validateSql(sql, c) + SqlValidator.validateSql(sql, args.map(_.tpe.toString), c) val Apply(fun, _) = reify(new SqlTemplate("")).tree c.Expr[SqlTemplate](Apply.apply(fun, Literal(Constant(sql)) :: args)) } diff --git a/src/main/scala/com/github/takezoe/scala/jdbc/validation/SchemaDef.scala b/src/main/scala/com/github/takezoe/scala/jdbc/validation/SchemaDef.scala index fd1abd7..3167bee 100644 --- a/src/main/scala/com/github/takezoe/scala/jdbc/validation/SchemaDef.scala +++ b/src/main/scala/com/github/takezoe/scala/jdbc/validation/SchemaDef.scala @@ -1,10 +1,18 @@ package com.github.takezoe.scala.jdbc.validation -import better.files.File +import java.io.{File, FileInputStream} + +import com.github.takezoe.scala.jdbc.IOUtils._ import com.fasterxml.jackson.databind.{DeserializationFeature, ObjectMapper} import com.fasterxml.jackson.module.scala.DefaultScalaModule -case class SchemaDef(tables: Seq[TableDef]) +case class SchemaDef(tables: Seq[TableDef], connection: Option[ConnectionDef]){ + def toMap: Map[String, TableDef] = { + tables.map { t => t.name -> t }.toMap + } +} + +case class ConnectionDef(driver: String, url: String, user: String, password: String) case class TableDef(name:String, columns: Seq[ColumnDef]) @@ -17,16 +25,23 @@ object SchemaDef { mapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false) mapper.registerModule(DefaultScalaModule) - def load(): Map[String, TableDef] = { - val file = File("schema.json") - val schema: Map[String, TableDef] = if(file.exists){ - val json = file.contentAsString - val schema = mapper.readValue(json, classOf[SchemaDef]) - schema.tables.map { t => t.name -> t }.toMap + def load(): Option[SchemaDef] = { + val file = new File("schema.json") + if(file.exists){ + // Load from file system + val json = using(new FileInputStream(file)){ in => + readStreamAsString(in) + } + Some(mapper.readValue(json, classOf[SchemaDef])) } else { - Map.empty + val in = Thread.currentThread.getContextClassLoader.getResourceAsStream("schema.json") + Option(in).map { in => + // Load from classpath + val json = using(in){ in => + readStreamAsString(in) + } + mapper.readValue(json, classOf[SchemaDef]) + } } - schema } - } \ No newline at end of file diff --git a/src/main/scala/com/github/takezoe/scala/jdbc/validation/SqlValidator.scala b/src/main/scala/com/github/takezoe/scala/jdbc/validation/SqlValidator.scala index 5b13dee..693eaf8 100644 --- a/src/main/scala/com/github/takezoe/scala/jdbc/validation/SqlValidator.scala +++ b/src/main/scala/com/github/takezoe/scala/jdbc/validation/SqlValidator.scala @@ -1,5 +1,7 @@ package com.github.takezoe.scala.jdbc.validation +import java.sql.{Date, DriverManager, Time, Timestamp} + import net.sf.jsqlparser.JSQLParserException import net.sf.jsqlparser.parser.CCJSqlParserUtil import net.sf.jsqlparser.statement.StatementVisitorAdapter @@ -9,28 +11,87 @@ import net.sf.jsqlparser.statement.update.Update import scala.reflect.macros.blackbox.Context +import com.github.takezoe.scala.jdbc.IOUtils._ +import com.github.takezoe.scala.jdbc.TypeMapper + object SqlValidator { - def validateSql(sql: String, c: Context): Unit = { - val schema = SchemaDef.load() - try { - val parse = CCJSqlParserUtil.parse(sql) - parse.accept(new StatementVisitorAdapter { - override def visit(select: net.sf.jsqlparser.statement.select.Select): Unit = { - new SelectValidator(c, select, schema).validate() - } - override def visit(insert: Insert): Unit = { - new InsertValidator(c, insert, schema).validate() + val typeMapper = new TypeMapper() // TODO It should be replaceable. + + def validateSql(sql: String, types: Seq[String], c: Context): Unit = { + SchemaDef.load() match { + case None => { + try { + CCJSqlParserUtil.parse(sql) + } catch { + case e: JSQLParserException => c.error(c.enclosingPosition, e.getCause.getMessage) } - override def visit(update: Update): Unit = { - new UpdateValidator(c, update, schema).validate() + } + case Some(SchemaDef(_, Some(connection))) => { + Class.forName(connection.driver) + val conn = DriverManager.getConnection(connection.url, connection.user, connection.password) + try { + conn.setAutoCommit(false) + using(conn.prepareStatement(adjustSql(sql))){ stmt => + try { + types.zipWithIndex.foreach { case (t, i) => + typeMapper.set(stmt, i + 1, getTestValue(t)) + } + stmt.execute() + } catch { + case e: Exception => c.error(c.enclosingPosition, e.toString) + } + } + } finally { + rollbackQuietly(conn) + closeQuietly(conn) } - override def visit(delete: Delete): Unit = { - new DeleteValidator(c, delete, schema).validate() + } + case Some(schemaDef) => { + try { + val parse = CCJSqlParserUtil.parse(sql) + val schema = schemaDef.toMap + parse.accept(new StatementVisitorAdapter { + override def visit(select: net.sf.jsqlparser.statement.select.Select): Unit = { + new SelectValidator(c, select, schema).validate() + } + override def visit(insert: Insert): Unit = { + new InsertValidator(c, insert, schema).validate() + } + override def visit(update: Update): Unit = { + new UpdateValidator(c, update, schema).validate() + } + override def visit(delete: Delete): Unit = { + new DeleteValidator(c, delete, schema).validate() + } + }) + } catch { + case e: JSQLParserException => c.error(c.enclosingPosition, e.getCause.getMessage) } - }) - } catch { - case e: JSQLParserException => c.error(c.enclosingPosition, e.getCause.getMessage) + } + } + } + + private def adjustSql(sql: String): String = { + if(sql.trim.toUpperCase.startsWith("SELECT")){ + sql + " LIMIT 0" + } else { + sql + } + } + + // TODO Move to TypeMapper? + private def getTestValue(t: String): Any = { + t match { + case "Int" => 0 + case "Long" => 0L + case "Double" => 0D + case "Short" => 0:Short + case "Float" => 0F + case "java.sql.Timestamp" => new Timestamp(System.currentTimeMillis) + case "java.sql.Date" => new Date(System.currentTimeMillis) + case "java.sql.Time" => new Time(System.currentTimeMillis) + case "String" => "-" } }