From 24770deae18b20771ea94c0d4c3ca06154ccd338 Mon Sep 17 00:00:00 2001 From: Stefan Zeiger Date: Wed, 11 Feb 2015 10:35:34 +0100 Subject: [PATCH] Add `DatabaseConfig` as a higher-level configuration abstraction: - DatabaseConfig encapsulates a Slick driver, a database configuration (as in Database.forConfig) and optionally extra configuration parameters. - SourceCodeGenerator can use DatabaseConfig for all configuration parameters. - The StaticDatabaseConfig annotation provides a configuration which is known at compile-time. It is used by the `tsql` interpolator. Tests in TypedStaticQueryTest. I could not find a way to get application.conf on the compiler class path, therefore it is referenced via a relative file: URL. --- common-test-resources/application.conf | 10 + reference.conf | 9 - .../slick/codegen/SourceCodeGenerator.scala | 96 +++-- .../test/jdbc/TypedStaticQueryTest.scala | 248 ++++++------ .../slick/backend/DatabaseComponent.scala | 10 + .../scala/slick/backend/DatabaseConfig.scala | 128 ++++++ .../scala/slick/driver/JdbcProfile.scala | 4 +- .../scala/scala/slick/jdbc/JdbcBackend.scala | 5 +- .../scala/slick/jdbc/MacroTreeBuilder.scala | 202 ++++++++++ .../scala/scala/slick/jdbc/StaticQuery.scala | 53 +-- .../scala/slick/jdbc/TypedStaticQuery.scala | 378 ------------------ .../slick/memory/DistributedBackend.scala | 4 + .../scala/slick/memory/HeapBackend.scala | 3 + 13 files changed, 584 insertions(+), 566 deletions(-) delete mode 100644 reference.conf create mode 100644 src/main/scala/scala/slick/backend/DatabaseConfig.scala create mode 100644 src/main/scala/scala/slick/jdbc/MacroTreeBuilder.scala delete mode 100644 src/main/scala/scala/slick/jdbc/TypedStaticQuery.scala diff --git a/common-test-resources/application.conf b/common-test-resources/application.conf index 1ed66e3732..a9117ce0e2 100644 --- a/common-test-resources/application.conf +++ b/common-test-resources/application.conf @@ -3,3 +3,13 @@ slick { unicodeDump = true sqlIndent = true } + +tsql { + driver = "scala.slick.driver.H2Driver$" + db { + connectionPool = disabled + driver = "org.h2.Driver" + url = "jdbc:h2:mem:tsqltest;INIT=runscript from 'slick-testkit/src/codegen/resources/dbs/tsql-test.sql'" + keepAliveConnection = true + } +} diff --git a/reference.conf b/reference.conf deleted file mode 100644 index 04465b7bb9..0000000000 --- a/reference.conf +++ /dev/null @@ -1,9 +0,0 @@ -typedsql { - -default { - url = "jdbc:h2:mem:test1;INIT=runscript from 'slick-testkit/src/codegen/resources/dbs/tsql-test.sql'" - jdbcDriver = "org.h2.Driver" - slickDriver = "scala.slick.driver.H2Driver" -} - -} \ No newline at end of file diff --git a/slick-codegen/src/main/scala/scala/slick/codegen/SourceCodeGenerator.scala b/slick-codegen/src/main/scala/scala/slick/codegen/SourceCodeGenerator.scala index d4a2a9903d..60abe2fdff 100644 --- a/slick-codegen/src/main/scala/scala/slick/codegen/SourceCodeGenerator.scala +++ b/slick-codegen/src/main/scala/scala/slick/codegen/SourceCodeGenerator.scala @@ -1,8 +1,13 @@ package scala.slick.codegen +import java.net.URI + import scala.concurrent.{ExecutionContext, Await} import scala.concurrent.duration.Duration +import scala.slick.backend.DatabaseConfig import scala.slick.{model => m} +import scala.slick.driver.JdbcProfile +import scala.slick.util.ConfigExtensionMethods.configExtensionMethods /** * A customizable code generator for working with Slick. @@ -52,45 +57,64 @@ class SourceCodeGenerator(model: m.Model) } /** A runnable class to execute the code generator without further setup */ -object SourceCodeGenerator{ - import scala.slick.driver.JdbcProfile - def main(args: Array[String]) = { - args.toList match { - case slickDriver :: jdbcDriver :: url :: outputFolder :: pkg :: tail if tail.size == 0 || tail.size == 2 => { - val driver: JdbcProfile = - Class.forName(slickDriver + "$").getField("MODULE$").get(null).asInstanceOf[JdbcProfile] - val dbFactory = driver.api.Database - val db = (tail match{ - case user :: password :: Nil => dbFactory.forURL(url, driver = jdbcDriver, user=user, password=password, keepAliveConnection=true) - case Nil => dbFactory.forURL(url, driver = jdbcDriver) - case _ => throw new Exception("This should never happen.") - }) - try { - val m = Await.result(db.run(driver.createModel(None, false)(ExecutionContext.global).withPinnedSession), Duration.Inf) - new SourceCodeGenerator(m).writeToFile(slickDriver,outputFolder,pkg) - } finally db.close - } - case _ => { - println(""" -Usage: - SourceCodeGenerator.main(Array(slickDriver, jdbcDriver, url, outputFolder, pkg)) - SourceCodeGenerator.main(Array(slickDriver, jdbcDriver, url, outputFolder, pkg, user, password)) - -slickDriver: Fully qualified name of Slick driver class, e.g. "scala.slick.driver.H2Driver" - -jdbcDriver: Fully qualified name of jdbc driver class, e.g. "org.h2.Driver" +object SourceCodeGenerator { -url: jdbc url, e.g. "jdbc:postgresql://localhost/test" - -outputFolder: Place where the package folder structure should be put - -pkg: Scala package the generated code should be places in + def run(slickDriver: String, jdbcDriver: String, url: String, outputDir: String, pkg: String, user: Option[String], password: Option[String]): Unit = { + val driver: JdbcProfile = + Class.forName(slickDriver + "$").getField("MODULE$").get(null).asInstanceOf[JdbcProfile] + val dbFactory = driver.api.Database + val db = dbFactory.forURL(url, driver = jdbcDriver, + user = user.getOrElse(null), password = password.getOrElse(null), keepAliveConnection = true) + try { + val m = Await.result(db.run(driver.createModel(None, false)(ExecutionContext.global).withPinnedSession), Duration.Inf) + new SourceCodeGenerator(m).writeToFile(slickDriver,outputDir,pkg) + } finally db.close + } -user: database connection user name + def run(uri: URI, outputDir: Option[String]): Unit = { + val dc = DatabaseConfig.forURI[JdbcProfile](uri) + val pkg = dc.config.getString("codegen.package") + val out = outputDir.getOrElse(dc.config.getStringOr("codegen.outputDir", ".")) + val slickDriver = if(dc.driverIsObject) dc.driverName else "new " + dc.driverName + try { + val m = Await.result(dc.db.run(dc.driver.createModel(None, false)(ExecutionContext.global).withPinnedSession), Duration.Inf) + new SourceCodeGenerator(m).writeToFile(slickDriver, out, pkg) + } finally dc.db.close + } -password: database connection password - """.trim - ) + def main(args: Array[String]): Unit = { + args.toList match { + case uri :: Nil => + run(new URI(uri), None) + case uri :: outputDir :: Nil => + run(new URI(uri), Some(outputDir)) + case slickDriver :: jdbcDriver :: url :: outputDir :: pkg :: Nil => + run(slickDriver, jdbcDriver, url, outputDir, pkg, None, None) + case slickDriver :: jdbcDriver :: url :: outputDir :: pkg :: user :: password :: Nil => + run(slickDriver, jdbcDriver, url, outputDir, pkg, Some(user), Some(password)) + case _ => { + println(""" + |Usage: + | SourceCodeGenerator configURI [outputDir] + | SourceCodeGenerator slickDriver jdbcDriver url outputDir pkg [user password] + | + |Options: + | configURI: A URL pointing to a standard database config file (a fragment is + | resolved as a path in the config), or just a fragment used as a path in + | application.conf on the class path + | slickDriver: Fully qualified name of Slick driver class, e.g. "scala.slick.driver.H2Driver" + | jdbcDriver: Fully qualified name of jdbc driver class, e.g. "org.h2.Driver" + | url: JDBC URL, e.g. "jdbc:postgresql://localhost/test" + | outputDir: Place where the package folder structure should be put + | pkg: Scala package the generated code should be places in + | user: database connection user name + | password: database connection password + | + |When using a config file, in addition to the standard config parameters from + |scala.slick.backend.DatabaseConfig you can set "codegen.package" and + |"codegen.outputDir". The latter can be overridden on the command line. + """.stripMargin.trim) + System.exit(1) } } } diff --git a/slick-testkit/src/test/scala/scala/slick/test/jdbc/TypedStaticQueryTest.scala b/slick-testkit/src/test/scala/scala/slick/test/jdbc/TypedStaticQueryTest.scala index 7b0b46ed9a..7d9bb8d891 100644 --- a/slick-testkit/src/test/scala/scala/slick/test/jdbc/TypedStaticQueryTest.scala +++ b/slick-testkit/src/test/scala/scala/slick/test/jdbc/TypedStaticQueryTest.scala @@ -2,140 +2,155 @@ package scala.slick.test.jdbc import org.junit.Test import org.junit.Assert._ +import scala.concurrent.Await +import scala.concurrent.duration.Duration +import scala.concurrent.ExecutionContext.Implicits.global +import scala.slick.backend.{DatabaseConfig, StaticDatabaseConfig} import scala.slick.collection.heterogenous.HNil import scala.slick.collection.heterogenous.syntax._ -import scala.slick.dbio.DBIO -import scala.slick.driver.JdbcDriver.api._ +import scala.slick.driver.JdbcProfile -@TSQLConfig("default") +@StaticDatabaseConfig("file:common-test-resources/application.conf#tsql") class TypedStaticQueryTest { - import scala.concurrent.ExecutionContext.Implicits.global - - val config = TypedStaticQuery.getConfigHandler() @Test - def testTypedInterpolation: Unit = config.connection run { - val id1 = 150 - val id2 = 1 - val s1 = tsql"select * from SUPPLIERS where SUP_ID = ${id1}" - val s2 = tsql"select * from COFFEES where SUP_ID = ${id2}" - assertEquals("select * from SUPPLIERS where SUP_ID = ?", s1.statements.head) - assertEquals("select * from COFFEES where SUP_ID = ?", s2.statements.head) - - val (total1, sales1) = (5, 4) - val s3 = tsql"select COF_NAME from COFFEES where SALES = ${sales1} and TOTAL = ${total1}" - - val s4 = tsql"select 1, '2', 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23" - - DBIO.seq( - s1.map { list1 => - val typedList1: Vector[(Int, String, String, String, String, String)] = list1 - assertEquals(Vector((150, "The High Ground", "100 Coffee Lane", "Meadows", "CA", "93966")), typedList1) - }, - s2.map { list2 => - val typedList2: Vector[(String, Int, Double, Int, Int)] = list2 - assertEquals(Vector(("coffee", 1, 2.3, 4, 5)), typedList2) - }, - s3.map { list3 => - val cof1: String = list3.head - assertEquals("coffee", cof1) - }, - s4.map { list4 => - val hlist1Typed: Int :: String :: Int :: Int :: Int :: Int :: Int :: Int :: Int :: Int :: Int :: Int :: Int :: Int :: Int :: Int :: Int :: Int :: Int :: Int :: Int :: Int :: Int :: HNil = list4.head - assertEquals(1 :: "2" :: 3 :: 4 :: 5 :: 6 :: 7 :: 8 :: 9 :: 10 :: 11 :: 12 :: 13 :: 14 :: 15 :: 16 :: 17 :: 18 :: 19 :: 20 :: 21 :: 22 :: 23 :: HNil, hlist1Typed) - } - ) + def testTypedInterpolation: Unit = { + val dc = DatabaseConfig.forAnnotation[JdbcProfile] + import dc.driver.api._ + try { + val id1 = 150 + val id2 = 1 + val s1 = tsql"select * from SUPPLIERS where SUP_ID = ${id1}" + val s2 = tsql"select * from COFFEES where SUP_ID = ${id2}" + assertEquals("select * from SUPPLIERS where SUP_ID = ?", s1.statements.head) + assertEquals("select * from COFFEES where SUP_ID = ?", s2.statements.head) + + val (total1, sales1) = (5, 4) + val s3 = tsql"select COF_NAME from COFFEES where SALES = ${sales1} and TOTAL = ${total1}" + + val s4 = tsql"select 1, '2', 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23" + + Await.result(dc.db.run(DBIO.seq( + s1.map { list1 => + val typedList1: Vector[(Int, String, String, String, String, String)] = list1 + assertEquals(Vector((150, "The High Ground", "100 Coffee Lane", "Meadows", "CA", "93966")), typedList1) + }, + s2.map { list2 => + val typedList2: Vector[(String, Int, Double, Int, Int)] = list2 + assertEquals(Vector(("coffee", 1, 2.3, 4, 5)), typedList2) + }, + s3.map { list3 => + val cof1: String = list3.head + assertEquals("coffee", cof1) + }, + s4.map { list4 => + val hlist1Typed: Int :: String :: Int :: Int :: Int :: Int :: Int :: Int :: Int :: Int :: Int :: Int :: Int :: Int :: Int :: Int :: Int :: Int :: Int :: Int :: Int :: Int :: Int :: HNil = list4.head + assertEquals(1 :: "2" :: 3 :: 4 :: 5 :: 6 :: 7 :: 8 :: 9 :: 10 :: 11 :: 12 :: 13 :: 14 :: 15 :: 16 :: 17 :: 18 :: 19 :: 20 :: 21 :: 22 :: 23 :: HNil, hlist1Typed) + } + )), Duration.Inf) + } finally dc.db.close() } @Test - def testCustomTypes: Unit = config.connection run { - import scala.slick.jdbc.SetParameter - - case class Foo(intVal: Int) - case class Bar(strVal: String) - - implicit val SetFoo = SetParameter[Foo] { (i, pp) => - SetParameter.SetInt(i.intVal, pp) - } - implicit val SetBar = SetParameter[Bar] { (s, pp) => - SetParameter.SetString(s.strVal, pp) - } - - val foo = new Foo(150) - val bar = new Bar("Something") - val num = 15 - - val s1 = tsql"select * from SUPPLIERS where SUP_ID = ${foo}" - val s2 = tsql"select * from SUPPLIERS where SUP_ID = ${num * 10}" - val s3 = tsql"select SUP_ID from SUPPLIERS" - val s4 = tsql"select CITY from SUPPLIERS" - - DBIO.seq( - s1.map { o1 => - val t1: Vector[(Int, String, String, String, String, String)] = o1 - assertEquals(Vector((150, "The High Ground", "100 Coffee Lane", "Meadows", "CA", "93966")), t1) - }, - s2.map { o2 => - val t2: Vector[(Int, String, String, String, String, String)] = o2 - assertEquals(Vector((150, "The High Ground", "100 Coffee Lane", "Meadows", "CA", "93966")), t2) - }, - s3.map { o3 => - val t3: Vector[Foo] = o3.map(Foo(_)) - assertEquals(Vector(Foo(101), Foo(150), Foo(49)), t3) - }, - s4.map { o4 => - val t4: Vector[Bar] = o4.map(Bar(_)) - assertEquals(List(Bar("Groundsville"), Bar("Meadows"), Bar("Mendocino")), t4) + def testCustomTypes: Unit = { + val dc = DatabaseConfig.forAnnotation[JdbcProfile] + import dc.driver.api._ + try { + import scala.slick.jdbc.SetParameter + + case class Foo(intVal: Int) + case class Bar(strVal: String) + + implicit val SetFoo = SetParameter[Foo] { (i, pp) => + SetParameter.SetInt(i.intVal, pp) } - ) + implicit val SetBar = SetParameter[Bar] { (s, pp) => + SetParameter.SetString(s.strVal, pp) + } + + val foo = new Foo(150) + val bar = new Bar("Something") + val num = 15 + + val s1 = tsql"select * from SUPPLIERS where SUP_ID = ${foo}" + val s2 = tsql"select * from SUPPLIERS where SUP_ID = ${num * 10}" + val s3 = tsql"select SUP_ID from SUPPLIERS" + val s4 = tsql"select CITY from SUPPLIERS" + + Await.result(dc.db.run(DBIO.seq( + s1.map { o1 => + val t1: Vector[(Int, String, String, String, String, String)] = o1 + assertEquals(Vector((150, "The High Ground", "100 Coffee Lane", "Meadows", "CA", "93966")), t1) + }, + s2.map { o2 => + val t2: Vector[(Int, String, String, String, String, String)] = o2 + assertEquals(Vector((150, "The High Ground", "100 Coffee Lane", "Meadows", "CA", "93966")), t2) + }, + s3.map { o3 => + val t3: Vector[Foo] = o3.map(Foo(_)) + assertEquals(Vector(Foo(101), Foo(150), Foo(49)), t3) + }, + s4.map { o4 => + val t4: Vector[Bar] = o4.map(Bar(_)) + assertEquals(List(Bar("Groundsville"), Bar("Meadows"), Bar("Mendocino")), t4) + } + )), Duration.Inf) + } finally dc.db.close() } @Test - def testPreparedQueries: Unit = config.connection run { - case class Supplier(id: Int, name: String) - implicit val supplierGetter = (arg: (Int, String)) => Supplier(arg._1, arg._2) - - def supplierForID(id: Int) = - tsql"select SUP_ID, SUP_NAME from SUPPLIERS where SUP_ID = $id" - def supplierForIdAndName(id: Int, name: String) = - tsql"select SUP_ID, SUP_NAME from SUPPLIERS where SUP_ID = $id and SUP_NAME = $name" - - val s1 = supplierForID(101) - val s2 = supplierForID(49) - val s3 = supplierForIdAndName(150, "The High Ground") - val s4 = supplierForIdAndName(49, "Superior Coffee") - - DBIO.seq( - s1.map { o1 => - val t1: Supplier = o1.map(supplierGetter).head - assertEquals(Supplier(101, "Acme, Inc."), t1) - }, - s2.map { o2 => - val t2: Supplier = o2.map(supplierGetter).head - assertEquals(Supplier(49, "Superior Coffee"), t2) - }, - s3.map { o3 => - val t3: Supplier = o3.map(supplierGetter).head - assertEquals(Supplier(150, "The High Ground"), o3) - }, - s4.map { o4 => - val t4: Supplier = o4.map(supplierGetter).head - assertEquals(Supplier(49, "Superior Coffee"), t4) - } - ) + def testPreparedQueries: Unit = { + val dc = DatabaseConfig.forAnnotation[JdbcProfile] + import dc.driver.api._ + try { + case class Supplier(id: Int, name: String) + implicit val supplierGetter = (arg: (Int, String)) => Supplier(arg._1, arg._2) + + def supplierForID(id: Int) = + tsql"select SUP_ID, SUP_NAME from SUPPLIERS where SUP_ID = $id" + def supplierForIdAndName(id: Int, name: String) = + tsql"select SUP_ID, SUP_NAME from SUPPLIERS where SUP_ID = $id and SUP_NAME = $name" + + val s1 = supplierForID(101) + val s2 = supplierForID(49) + val s3 = supplierForIdAndName(150, "The High Ground") + val s4 = supplierForIdAndName(49, "Superior Coffee") + + Await.result(dc.db.run(DBIO.seq( + s1.map { o1 => + val t1: Supplier = o1.map(supplierGetter).head + assertEquals(Supplier(101, "Acme, Inc."), t1) + }, + s2.map { o2 => + val t2: Supplier = o2.map(supplierGetter).head + assertEquals(Supplier(49, "Superior Coffee"), t2) + }, + s3.map { o3 => + val t3: Supplier = o3.map(supplierGetter).head + assertEquals(Supplier(150, "The High Ground"), t3) + }, + s4.map { o4 => + val t4: Supplier = o4.map(supplierGetter).head + assertEquals(Supplier(49, "Superior Coffee"), t4) + } + )), Duration.Inf) + } finally dc.db.close() } - @Test - def testAllStatements: Unit = config.connection run { + @Test + def testAllStatements: Unit = { + val dc = DatabaseConfig.forAnnotation[JdbcProfile] + import dc.driver.api._ + try { case class Supplier(id: Int, name: String) implicit val supplierGetter = (arg: (Int, String)) => Supplier(arg._1, arg._2) - + val testUnitDML = (x: Vector[Int]) => assertEquals(1, x.head) - + val s1 = tsql"select SUP_ID, SUP_NAME from SUPPLIERS where SUP_ID = 102" val s2 = tsql"select SUP_ID, SUP_NAME from SUPPLIERS where SUP_ID = 103" - - DBIO.seq( + + Await.result(dc.db.run(DBIO.seq( tsql"""create table "SUPPLIERS2" ("SUP_ID" INTEGER NOT NULL PRIMARY KEY,"SUP_NAME" VARCHAR NOT NULL);""", tsql"""INSERT INTO SUPPLIERS VALUES(102, 'Acme, Inc. Next', '99 Market Street', 'Groundsville', 'CA', '95199');""" map testUnitDML, tsql"""INSERT INTO SUPPLIERS VALUES(103, 'Coffee Retailers Corp.', '9 Random Street', 'Ville', 'LA', '63195');""" map testUnitDML, @@ -160,6 +175,7 @@ class TypedStaticQueryTest { tsql"""DELETE FROM SUPPLIERS WHERE SUP_ID = '102';""" map testUnitDML, tsql"""DELETE FROM SUPPLIERS WHERE SUP_ID = '103';""" map testUnitDML, tsql"""drop table "SUPPLIERS2" """ - ) - } + ).withPinnedSession), Duration.Inf) + } finally dc.db.close() + } } diff --git a/src/main/scala/scala/slick/backend/DatabaseComponent.scala b/src/main/scala/scala/slick/backend/DatabaseComponent.scala index 4cf9d6b7e6..4c5e859ebe 100644 --- a/src/main/scala/scala/slick/backend/DatabaseComponent.scala +++ b/src/main/scala/scala/slick/backend/DatabaseComponent.scala @@ -2,6 +2,8 @@ package scala.slick.backend import java.util.concurrent.atomic.AtomicLong +import com.typesafe.config.Config + import scala.language.existentials import scala.concurrent.{Promise, ExecutionContext, Future} @@ -38,6 +40,14 @@ trait DatabaseComponent { self => /** The database factory */ val Database: DatabaseFactory + /** Create a Database instance through [[https://github.com/typesafehub/config Typesafe Config]]. + * The supported config keys are backend-specific. This method is used by `DatabaseConfig`. + * @param path The path in the configuration file for the database configuration, or an empty + * string for the top level of the `Config` object. + * @param config The `Config` object to read from. + */ + def createDatabase(config: Config, path: String): Database + /** A database instance to which connections can be created. */ trait DatabaseDef { this: Database => /** Create a new session. The session needs to be closed explicitly by calling its close() method. */ diff --git a/src/main/scala/scala/slick/backend/DatabaseConfig.scala b/src/main/scala/scala/slick/backend/DatabaseConfig.scala new file mode 100644 index 0000000000..4ed70851a3 --- /dev/null +++ b/src/main/scala/scala/slick/backend/DatabaseConfig.scala @@ -0,0 +1,128 @@ +package scala.slick.backend + +import scala.language.experimental.macros + +import java.net.{URL, URI} +import scala.annotation.{StaticAnnotation, Annotation} +import scala.reflect.ClassTag +import scala.reflect.macros.Context +import scala.util.control.NonFatal +import scala.slick.SlickException +import scala.slick.profile.BasicProfile +import com.typesafe.config.{ConfigFactory, Config} + +/** A configuration for a Database plus a matching Slick driver. */ +trait DatabaseConfig[P <: BasicProfile] { + /** Get the configured Database. It is instantiated lazily when this method is called for the + * first time, and must be closed after use. */ + def db: P#Backend#Database + + /** The configured driver. */ + val driver: P + + /** The raw configuration. */ + def config: Config + + /** The name of the driver class or object (without a trailing "$"). */ + def driverName: String + + /** Whether the `driverName` represents an object instead of a class. */ + def driverIsObject: Boolean +} + +object DatabaseConfig { + /** Load a driver and database configuration through + * [[https://github.com/typesafehub/config Typesafe Config]]. + * + * The following config parameters are available: + * + * + * @param path The path in the configuration file for the database configuration (e.g. `foo.bar` + * would find a driver name at config key `foo.bar.driver`) or an empty string + * for the top level of the `Config` object. + * @param config The `Config` object to read from. This defaults to the global app config + * (e.g. in `application.conf` at the root of the class path) if not specified. + */ + def forConfig[P <: BasicProfile : ClassTag](path: String, config: Config = ConfigFactory.load()): DatabaseConfig[P] = { + val n = config.getString((if(path.isEmpty) "" else path + ".") + "driver") + val untypedP = try { + (if(n.endsWith("$")) Class.forName(n).getField("MODULE$").get(null) + else Class.forName(n).newInstance()) + } catch { case NonFatal(ex) => + throw new SlickException(s"""Error getting instance of Slick driver "$n"""", ex) + } + val pClass = implicitly[ClassTag[P]].runtimeClass + if(!pClass.isInstance(untypedP)) + throw new SlickException(s"Configured Slick driver $n is not an instance of requested profile ${pClass.getName}") + val root = config + new DatabaseConfig[P] { + lazy val db: P#Backend#Database = + driver.backend.createDatabase(root, (if(path.isEmpty) "" else path + ".") + "db") + val driver: P = untypedP.asInstanceOf[P] + lazy val config: Config = if(path.isEmpty) root else root.getConfig(path) + def driverName = if(driverIsObject) n.substring(0, n.length-1) else n + def driverIsObject = n.endsWith("$") + } + } + + /** Load a driver and database configuration from the specified URI. If only a fragment name + * is given, it is resolved as a path in the global app config (e.g. in `application.conf` at + * the root of the class path), otherwise as a path in the configuration located at the URI + * without the fragment, which must be a valid URL. Without a fragment, the whole config object + * is used. */ + def forURI[P <: BasicProfile : ClassTag](uri: URI): DatabaseConfig[P] = { + val (base, path) = { + val f = uri.getRawFragment + val s = uri.toString + if(s.isEmpty) (null, "") + else if(f eq null) (s, "") + else if(s.startsWith("#")) (null, uri.getFragment) + else (s.substring(0, s.length-f.length-1), uri.getFragment) + } + val root = + if(base eq null) ConfigFactory.load() + else ConfigFactory.parseURL(new URL(base)).resolve() + forConfig[P](path, root) + } + + /** Load a driver and database configuration from the URI specified in a [[StaticDatabaseConfig]] + * annotation in the static scope of the caller. */ + def forAnnotation[P <: BasicProfile](implicit ct: ClassTag[P]): DatabaseConfig[P] = + macro StaticDatabaseConfigMacros.getImpl[P] +} + +/** An annotation for injecting a DatabaseConfig at compile time. The URI parameter must be a + * literal String. This annotation is required for providing a statically scoped database + * configuration to the `tsql` interpolator. */ +final class StaticDatabaseConfig(val uri: String) extends Annotation with StaticAnnotation + +object StaticDatabaseConfigMacros { + private[slick] def getURI(c: Context): String = { + import c.universe._ + + def findUri(ann: Seq[Tree]): Option[String] = + ann.map(c.typeCheck(_, pt = weakTypeOf[StaticDatabaseConfig], silent = true)).collectFirst { + case Apply(Select(_, _), List(Literal(Constant(uri: String)))) => uri + } + + val methConf = Option(c.enclosingMethod).filter(_ != EmptyTree).map(_.asInstanceOf[MemberDef]) + .flatMap(md => findUri(md.mods.annotations)) + val classConf = findUri(c.enclosingClass.asInstanceOf[MemberDef].mods.annotations) + methConf.orElse(classConf).getOrElse( + c.abort(c.enclosingPosition, "No @StaticDatabaseConfig annotation found in enclosing scope")) + } + + def getImpl[P <: BasicProfile : c.WeakTypeTag](c: Context)(ct: c.Expr[ClassTag[P]]): c.Expr[DatabaseConfig[P]] = { + import c.universe._ + val uri = c.Expr[String](Literal(Constant(getURI(c)))) + reify(DatabaseConfig.forURI[P](new URI(uri.splice))(ct.splice)) + } +} diff --git a/src/main/scala/scala/slick/driver/JdbcProfile.scala b/src/main/scala/scala/slick/driver/JdbcProfile.scala index febf636963..3fa88fd810 100644 --- a/src/main/scala/scala/slick/driver/JdbcProfile.scala +++ b/src/main/scala/scala/slick/driver/JdbcProfile.scala @@ -4,7 +4,7 @@ import scala.language.{implicitConversions, higherKinds} import scala.slick.ast.{BaseTypedType, Node} import scala.slick.compiler.{Phase, QueryCompiler, InsertCompiler} import scala.slick.lifted._ -import scala.slick.jdbc.{TypedStaticQuery => TSQ, _} +import scala.slick.jdbc._ import scala.slick.profile.{SqlDriver, SqlProfile, Capability} /** A profile for accessing SQL databases via JDBC. All drivers for JDBC-based databases @@ -86,8 +86,6 @@ trait JdbcProfile extends SqlProfile with JdbcActionComponent new JdbcActionExtensionMethods[E, R, S](a) implicit def actionBasedSQLInterpolation(s: StringContext) = new ActionBasedSQLInterpolation(s) - val TypedStaticQuery = TSQ - type TSQLConfig = TSQ.TSQLConfig } @deprecated("Use 'api' instead of 'simple' or 'Implicit' to import the new API", "3.0") diff --git a/src/main/scala/scala/slick/jdbc/JdbcBackend.scala b/src/main/scala/scala/slick/jdbc/JdbcBackend.scala index 0e9e4eb01e..90706c942c 100644 --- a/src/main/scala/scala/slick/jdbc/JdbcBackend.scala +++ b/src/main/scala/scala/slick/jdbc/JdbcBackend.scala @@ -31,6 +31,8 @@ trait JdbcBackend extends RelationalBackend { val Database = new DatabaseFactoryDef {} val backend: JdbcBackend = this + def createDatabase(config: Config, path: String): Database = Database.forConfig(path, config) + class DatabaseDef(val source: JdbcDataSource, val executor: AsyncExecutor) extends super.DatabaseDef { /** The DatabaseCapabilities, accessed through a Session and created by the * first Session that needs them. Access does not need to be synchronized @@ -206,7 +208,8 @@ trait JdbcBackend extends RelationalBackend { * [[SlickException]]. * * @param path The path in the configuration file for the database configuration (e.g. `foo.bar` - * would find a database URL at config key `foo.bar.url`) + * would find a database URL at config key `foo.bar.url`) or an empty string for + * the top level of the `Config` object. * @param config The `Config` object to read from. This defaults to the global app config * (e.g. in `application.conf` at the root of the class path) if not specified. * @param driver An optional JDBC driver to call directly. If this is set to a non-null value, diff --git a/src/main/scala/scala/slick/jdbc/MacroTreeBuilder.scala b/src/main/scala/scala/slick/jdbc/MacroTreeBuilder.scala new file mode 100644 index 0000000000..4b79db7c29 --- /dev/null +++ b/src/main/scala/scala/slick/jdbc/MacroTreeBuilder.scala @@ -0,0 +1,202 @@ +package scala.slick.jdbc + +import scala.language.experimental.macros + +import scala.collection.mutable.ListBuffer +import scala.reflect.ClassTag +import scala.reflect.macros.Context + +/** AST builder used by the SQL interpolation macros. */ +private[jdbc] class MacroTreeBuilder[C <: Context](val c: C)(paramsList: List[C#Expr[Any]]) { + import c.universe._ + + def abort(msg: String) = c.abort(c.enclosingPosition, msg) + + // create a list of strings passed to this interpolation + lazy val rawQueryParts: List[String] = { + //Deconstruct macro application to determine the passed string and the actual parameters + val Apply(Select(Apply(_, List(Apply(_, strArg))), _), paramList) = c.macroApplication + strArg map { + case Literal(Constant(x: String)) => x + case _ => abort("The interpolation contained something other than constants...") + } + } + + /** + * Create a Tree of the static name of a class + * eg java.lang.String becomes + * Select(Select(Select(Ident(nme.ROOTPKG), "java"), "lang"), "String") + */ + private def createClassTreeFromString(classString: String, generator: String => Name): Tree = { + val tokens = classString.split('.').toList + val packages = tokens.dropRight(1) map (newTermName(_)) + val classType = generator(tokens.last) + val firstPackage = Ident(nme.ROOTPKG) + val others = (packages :+ classType) + others.foldLeft[Tree](firstPackage)((prev, elem) => { + Select(prev, elem) + }) + } + + /** + * Creates a tree equivalent to an implicity resolution of a given type + * eg for type GetResult[Int], this function gives the tree equivalent of + * scala.Predef.implicitly[GetResult[Int]] + */ + def implicitTree(reqType: Tree, baseType: Tree) = TypeApply( + ImplicitlyTree, List(AppliedTypeTree(baseType, List(reqType))) + ) + + //Some commonly used trees that are created on demand + lazy val GetResultTypeTree = createClassTreeFromString("scala.slick.jdbc.GetResult", newTypeName(_)) + lazy val SetParameterTypeTree = createClassTreeFromString("scala.slick.jdbc.SetParameter", newTypeName(_)) + lazy val TypedStaticQueryTypeTree = createClassTreeFromString("scala.slick.jdbc.TypedStaticQuery", newTypeName(_)) + lazy val GetResultTree = createClassTreeFromString("scala.slick.jdbc.GetResult", newTermName(_)) + lazy val SetParameterTree = createClassTreeFromString("scala.slick.jdbc.SetParameter", newTermName(_)) + lazy val ImplicitlyTree = createClassTreeFromString("scala.Predef.implicitly", newTermName(_)) + lazy val HeterogenousTree = createClassTreeFromString("scala.slick.collection.heterogenous", newTermName(_)) + lazy val VectorTree = createClassTreeFromString("scala.collection.immutable.Vector", newTermName(_)) + lazy val GetNoResultTree = createClassTreeFromString("scala.slick.jdbc.TypedStaticQuery.GetNoResult", newTermName(_)) + + /** + * Creates the tree for GetResult[] of the tsql macro + */ + def rconvTree(resultTypes: Vector[ClassTag[_]]) = { + val resultTypeTrees = resultTypes.map (_.runtimeClass.getCanonicalName match { + case "int" => TypeTree(typeOf[Int]) + case "byte" => TypeTree(typeOf[Byte]) + case "long" => TypeTree(typeOf[Long]) + case "short" => TypeTree(typeOf[Short]) + case "float" => TypeTree(typeOf[Float]) + case "double" => TypeTree(typeOf[Double]) + case "boolean" => TypeTree(typeOf[Boolean]) + case x => TypeTree(c.mirror.staticClass(x).selfType) + }) + + resultTypes.size match { + case 0 => implicitTree(TypeTree(typeOf[Int]) , GetResultTypeTree) + case 1 => implicitTree(resultTypeTrees(0), GetResultTypeTree) + case n if (n <= 22) => + implicitTree(AppliedTypeTree( + Select(Select(Ident(nme.ROOTPKG), newTermName("scala")), newTypeName("Tuple" + resultTypes.size)), + resultTypeTrees.toList + ), GetResultTypeTree) + case n => + val rtypeTree = { + val zero = TypeTree(typeOf[scala.slick.collection.heterogenous.syntax.HNil]) + val :: = Select(Select(HeterogenousTree, newTermName("syntax")), newTypeName("$colon$colon")) + resultTypeTrees.foldRight[TypTree](zero) { (typ, prev) => + AppliedTypeTree(::, List(typ, prev)) + } + } + val zero = Select(HeterogenousTree, newTermName("HNil")) + val zipped = (0 until n) zip resultTypeTrees + val << = Select(Ident(newTermName("p")), newTermName("$less$less")) + Apply( + TypeApply( + Select(GetResultTree, newTermName("apply")), + List(rtypeTree) + ), + List( + Function( + List(ValDef(Modifiers(Flag.PARAM), newTermName("p"), TypeTree(), EmptyTree)), + Block( + zipped.map { tup => + val (i: Int, typ: Tree) = tup + ValDef(Modifiers(), newTermName("gr" + i), TypeTree(), implicitTree(typ, GetResultTypeTree)) + }.toList, + zipped.foldRight[Tree](zero) { (tup, prev) => + val (i: Int, typ: Tree) = tup + Block( + List(ValDef(Modifiers(), newTermName("pv" + i), TypeTree(), Apply(<<, List(Ident(newTermName("gr" + i)))))), + Apply(Select(prev, newTermName("$colon$colon")), List(Ident(newTermName("pv" + i)))) + ) + } + ) + ) + ) + ) + } + } + + /** + * Processing of the query to fill in constants and prepare + * the query to be used and a list of SetParameter[] + */ + private lazy val interpolationResultParams: (List[Tree], Tree) = { + def decode(s: String): (String, Boolean) = { + if(s.endsWith("##")) { + val (str, bool) = decode(s.substring(0, s.length-2)) + (str + "#", bool) + } else if(s.endsWith("#")) { + (s.substring(0, s.length-1), true) + } else { + (s, false) + } + } + + /** Fuse adjacent string literals */ + def fuse(l: List[Tree]): List[Tree] = l match { + case Literal(Constant(s1: String)) :: Literal(Constant(s2: String)) :: ss => fuse(Literal(Constant(s1 + s2)) :: ss) + case s :: ss => s :: fuse(ss) + case Nil => Nil + } + + if(rawQueryParts.length == 1) + (List(Literal(Constant(rawQueryParts.head))), Select(SetParameterTree, newTermName("SetUnit"))) + else { + val queryString = new ListBuffer[Tree] + val remaining = new ListBuffer[c.Expr[SetParameter[Unit]]] + paramsList.asInstanceOf[List[c.Expr[Any]]].iterator.zip(rawQueryParts.iterator).foreach { case (param, rawQueryPart) => + val (queryPart, append) = decode(rawQueryPart) + queryString.append(Literal(Constant(queryPart))) + if(append) queryString.append(param.tree) + else { + queryString.append(Literal(Constant("?"))) + remaining += c.Expr[SetParameter[Unit]] { + Apply( + Select( + implicitTree(TypeTree(param.actualType), SetParameterTypeTree), + newTermName("applied") + ), + List(param.tree) + ) + } + } + } + queryString.append(Literal(Constant(rawQueryParts.last))) + val pconv = + if(remaining.isEmpty) Select(SetParameterTree, newTermName("SetUnit")) + else Apply( + Select(SetParameterTree, newTermName("apply")), + List( + Function( + List( + ValDef(Modifiers(Flag.PARAM), newTermName("u"), TypeTree(), EmptyTree), + ValDef(Modifiers(Flag.PARAM), newTermName("pp"), TypeTree(), EmptyTree) + ), + Block( + remaining.toList map ( sp => + Apply( + Select(sp.tree, newTermName("apply")), + List(Ident(newTermName("u")), Ident(newTermName("pp"))) + ) + ), Literal(Constant(())) + ) + ) + ) + ) + (fuse(queryString.result()), pconv) + } + } + + lazy val queryParts: Tree = + Apply(Select(VectorTree, newTermName("apply")), interpolationResultParams._1) + + def staticQueryString: String = interpolationResultParams._1 match { + case Literal(Constant(s: String)) :: Nil => s + case _ => c.abort(c.enclosingPosition, "Only constant strings may be used after '#$' in 'tsql' interpolation") + } + + lazy val pconvTree: Tree = interpolationResultParams._2 +} diff --git a/src/main/scala/scala/slick/jdbc/StaticQuery.scala b/src/main/scala/scala/slick/jdbc/StaticQuery.scala index 07182e4875..8be7a1b069 100644 --- a/src/main/scala/scala/slick/jdbc/StaticQuery.scala +++ b/src/main/scala/scala/slick/jdbc/StaticQuery.scala @@ -1,17 +1,23 @@ package scala.slick.jdbc +import java.net.URI + +import com.typesafe.config.ConfigException + import scala.language.experimental.macros import scala.language.implicitConversions +import scala.reflect.ClassTag import scala.reflect.macros.Context import scala.collection.mutable.ArrayBuffer import java.sql.PreparedStatement +import scala.slick.SlickException +import scala.slick.backend.{DatabaseConfig, StaticDatabaseConfigMacros, StaticDatabaseConfig} import scala.slick.dbio.Effect +import scala.slick.driver.JdbcProfile import scala.slick.profile.SqlStreamingAction -import TypedStaticQuery.{MacroConnectionHelper, MacroTreeBuilder} - ///////////////////////////////////////////////////////////////////////////////// Invoker-based API @@ -75,8 +81,7 @@ class SQLInterpolation(val s: StringContext) extends AnyVal { object SQLInterpolation { def sqlImpl(ctxt: Context)(param: ctxt.Expr[Any]*): ctxt.Expr[SQLInterpolationResult] = { import ctxt.universe._ - val macroConnHelper = new MacroConnectionHelper(ctxt) - val macroTreeBuilder = new MacroTreeBuilder[ctxt.type](ctxt)(param.toList, macroConnHelper.rawQueryParts) + val macroTreeBuilder = new MacroTreeBuilder[ctxt.type](ctxt)(param.toList) reify { SQLInterpolationResult( ctxt.Expr[Seq[Any]] (macroTreeBuilder.queryParts).splice, @@ -87,8 +92,7 @@ object SQLInterpolation { def sqluImpl(ctxt: Context)(param: ctxt.Expr[Any]*): ctxt.Expr[StaticQuery[Unit, Int]] = { import ctxt.universe._ - val macroConnHelper = new MacroConnectionHelper(ctxt) - val macroTreeBuilder = new MacroTreeBuilder[ctxt.type](ctxt)(param.toList, macroConnHelper.rawQueryParts) + val macroTreeBuilder = new MacroTreeBuilder[ctxt.type](ctxt)(param.toList) reify { val res: SQLInterpolationResult = SQLInterpolationResult( ctxt.Expr[Seq[Any]] (macroTreeBuilder.queryParts).splice, @@ -127,8 +131,7 @@ class ActionBasedSQLInterpolation(val s: StringContext) extends AnyVal { object ActionBasedSQLInterpolation { def sqlImpl(ctxt: Context)(param: ctxt.Expr[Any]*): ctxt.Expr[SQLActionBuilder] = { import ctxt.universe._ - val macroConnHelper = new MacroConnectionHelper(ctxt) - val macroTreeBuilder = new MacroTreeBuilder[ctxt.type](ctxt)(param.toList, macroConnHelper.rawQueryParts) + val macroTreeBuilder = new MacroTreeBuilder[ctxt.type](ctxt)(param.toList) reify { SQLActionBuilder( ctxt.Expr[Seq[Any]] (macroTreeBuilder.queryParts).splice, @@ -139,8 +142,7 @@ object ActionBasedSQLInterpolation { def sqluImpl(ctxt: Context)(param: ctxt.Expr[Any]*): ctxt.Expr[SqlStreamingAction[Effect, Vector[Int], Int]] = { import ctxt.universe._ - val macroConnHelper = new MacroConnectionHelper(ctxt) - val macroTreeBuilder = new MacroTreeBuilder[ctxt.type](ctxt)(param.toList, macroConnHelper.rawQueryParts) + val macroTreeBuilder = new MacroTreeBuilder[ctxt.type](ctxt)(param.toList) reify { val res: SQLActionBuilder = SQLActionBuilder( ctxt.Expr[Seq[Any]] (macroTreeBuilder.queryParts).splice, @@ -151,23 +153,28 @@ object ActionBasedSQLInterpolation { } def tsqlImpl(ctxt: Context)(param: ctxt.Expr[Any]*): ctxt.Expr[SqlStreamingAction[Effect, Vector[Any], Any]] = { import ctxt.universe._ - val macroConnHelper = new MacroConnectionHelper(ctxt) - val macroTreeBuilder = new MacroTreeBuilder[ctxt.type](ctxt)(param.toList, macroConnHelper.rawQueryParts) - - val rTypes = macroConnHelper.configHandler.connection withSession { - _.withPreparedStatement(macroTreeBuilder.staticQueryString) { - _.getMetaData match { - case null => Vector() - case resultMeta => Vector.tabulate(resultMeta.getColumnCount) { i => - val configHandler = macroConnHelper.configHandler -// val driver = if (configHandler.slickDriver.isDefined) configHandler.SlickDriver else scala.slick.driver.JdbcDriver - val driver = scala.slick.driver.JdbcDriver - val modelBuilder = driver.createModelBuilder(Nil, true)(scala.concurrent.ExecutionContext.global) + val macroTreeBuilder = new MacroTreeBuilder[ctxt.type](ctxt)(param.toList) + + val uri = StaticDatabaseConfigMacros.getURI(ctxt) + //TODO The database configuration and connection should be cached for subsequent macro invocations + val dc = + try DatabaseConfig.forURI[JdbcProfile](new URI((uri))) catch { + case ex @ (_: ConfigException | _: SlickException) => + ctxt.abort(ctxt.enclosingPosition, s"""Cannot load @StaticDatabaseConfig("$uri"): ${ex.getMessage}""") + } + val rTypes = try { + dc.db withSession { + _.withPreparedStatement(macroTreeBuilder.staticQueryString) { + _.getMetaData match { + case null => Vector() + case resultMeta => Vector.tabulate(resultMeta.getColumnCount) { i => + val modelBuilder = dc.driver.createModelBuilder(Nil, true)(scala.concurrent.ExecutionContext.global) modelBuilder.jdbcTypeToScala(resultMeta.getColumnType(i + 1)) + } } } } - } + } finally dc.db.close() reify { val rconv = ctxt.Expr[GetResult[Any]](macroTreeBuilder.rconvTree(rTypes)).splice diff --git a/src/main/scala/scala/slick/jdbc/TypedStaticQuery.scala b/src/main/scala/scala/slick/jdbc/TypedStaticQuery.scala deleted file mode 100644 index f6d71cfbd6..0000000000 --- a/src/main/scala/scala/slick/jdbc/TypedStaticQuery.scala +++ /dev/null @@ -1,378 +0,0 @@ -package scala.slick.jdbc - -import com.typesafe.config.{ ConfigFactory, ConfigException } -import scala.annotation.{Annotation, StaticAnnotation} -import scala.collection.mutable.{ListBuffer, ArrayBuffer} -import scala.language.experimental.macros -import scala.reflect.ClassTag -import scala.reflect.macros.Context -import scala.slick.collection.heterogenous._ - -/** - * An implementation of the macros involved in the Plain SQL API - */ -object TypedStaticQuery { - - /** - * An annotation used with the tsql interpolation macro for defining - * the configuration of compile-time database connections through a config file - */ - final class TSQLConfig(val dbName: String) extends Annotation with StaticAnnotation - - /** - * The function used to fetch a ConfigHandler instance that ensures - * uniform database connections at compile-time and at run-time. - * tsql interpolation macro must always be used in conjunction with a - * ConfigHandler, more specifically a ConfigHandler.connection - */ - def getConfigHandler(): TypedStaticQuery.ConfigHandler = macro getCHimpl - - def getCHimpl(ctxt: Context)(): ctxt.Expr[TypedStaticQuery.ConfigHandler] = { - val macroConnHelper = new MacroConnectionHelper(ctxt) { - override val c: ctxt.type = ctxt - } - ctxt.Expr(macroConnHelper.configHandlerTree) - } - - /** AST builder used by the interpolation macros. */ - private[jdbc] class MacroTreeBuilder[C <: Context](val c: C)(paramsList: List[C#Expr[Any]], rawQueryParts: List[String]) { - import c.universe._ - - /** - * Create a Tree of the static name of a class - * eg java.lang.String becomes - * Select(Select(Select(Ident(nme.ROOTPKG), "java"), "lang"), "String") - */ - private def createClassTreeFromString(classString: String, generator: String => Name): Tree = { - val tokens = classString.split('.').toList - val packages = tokens.dropRight(1) map (newTermName(_)) - val classType = generator(tokens.last) - val firstPackage = Ident(nme.ROOTPKG) - val others = (packages :+ classType) - others.foldLeft[Tree](firstPackage)((prev, elem) => { - Select(prev, elem) - }) - } - - /** - * Creates a tree equivalent to an implicity resolution of a given type - * eg for type GetResult[Int], this function gives the tree equivalent of - * scala.Predef.implicitly[GetResult[Int]] - */ - def implicitTree(reqType: Tree, baseType: Tree) = TypeApply( - ImplicitlyTree, List(AppliedTypeTree(baseType, List(reqType))) - ) - - //Some commonly used trees that are created on demand - lazy val GetResultTypeTree = createClassTreeFromString("scala.slick.jdbc.GetResult", newTypeName(_)) - lazy val SetParameterTypeTree = createClassTreeFromString("scala.slick.jdbc.SetParameter", newTypeName(_)) - lazy val TypedStaticQueryTypeTree = createClassTreeFromString("scala.slick.jdbc.TypedStaticQuery", newTypeName(_)) - lazy val GetResultTree = createClassTreeFromString("scala.slick.jdbc.GetResult", newTermName(_)) - lazy val SetParameterTree = createClassTreeFromString("scala.slick.jdbc.SetParameter", newTermName(_)) - lazy val ImplicitlyTree = createClassTreeFromString("scala.Predef.implicitly", newTermName(_)) - lazy val HeterogenousTree = createClassTreeFromString("scala.slick.collection.heterogenous", newTermName(_)) - //lazy val ArrayTree = createClassTreeFromString("scala.Array", newTermName(_)) - lazy val VectorTree = createClassTreeFromString("scala.collection.immutable.Vector", newTermName(_)) - lazy val GetNoResultTree = createClassTreeFromString("scala.slick.jdbc.TypedStaticQuery.GetNoResult", newTermName(_)) - - /** - * Creates the tree for GetResult[] of the tsql macro - */ - def rconvTree(resultTypes: Vector[ClassTag[_]]) = { - val resultTypeTrees = resultTypes.map (_.runtimeClass.getCanonicalName match { - case "int" => TypeTree(typeOf[Int]) - case "byte" => TypeTree(typeOf[Byte]) - case "long" => TypeTree(typeOf[Long]) - case "short" => TypeTree(typeOf[Short]) - case "float" => TypeTree(typeOf[Float]) - case "double" => TypeTree(typeOf[Double]) - case "boolean" => TypeTree(typeOf[Boolean]) - case x => TypeTree(c.mirror.staticClass(x).selfType) - }) - - resultTypes.size match { - case 0 => implicitTree(TypeTree(typeOf[Int]) , GetResultTypeTree) - case 1 => implicitTree(resultTypeTrees(0), GetResultTypeTree) - case n if (n <= 22) => - implicitTree(AppliedTypeTree( - Select(Select(Ident(nme.ROOTPKG), newTermName("scala")), newTypeName("Tuple" + resultTypes.size)), - resultTypeTrees.toList - ), GetResultTypeTree) - case n => - val rtypeTree = { - val zero = TypeTree(typeOf[scala.slick.collection.heterogenous.syntax.HNil]) - val :: = Select(Select(HeterogenousTree, newTermName("syntax")), newTypeName("$colon$colon")) - resultTypeTrees.foldRight[TypTree](zero) { (typ, prev) => - AppliedTypeTree(::, List(typ, prev)) - } - } - val zero = Select(HeterogenousTree, newTermName("HNil")) - val zipped = (0 until n) zip resultTypeTrees - val << = Select(Ident(newTermName("p")), newTermName("$less$less")) - Apply( - TypeApply( - Select(GetResultTree, newTermName("apply")), - List(rtypeTree) - ), - List( - Function( - List(ValDef(Modifiers(Flag.PARAM), newTermName("p"), TypeTree(), EmptyTree)), - Block( - zipped.map { tup => - val (i: Int, typ: Tree) = tup - ValDef(Modifiers(), newTermName("gr" + i), TypeTree(), implicitTree(typ, GetResultTypeTree)) - }.toList, - zipped.foldRight[Tree](zero) { (tup, prev) => - val (i: Int, typ: Tree) = tup - Block( - List(ValDef(Modifiers(), newTermName("pv" + i), TypeTree(), Apply(<<, List(Ident(newTermName("gr" + i)))))), - Apply(Select(prev, newTermName("$colon$colon")), List(Ident(newTermName("pv" + i)))) - ) - } - ) - ) - ) - ) - } - } - - /** - * Processing of the query to fill in constants and prepare - * the query to be used and a list of SetParameter[] - */ - private lazy val interpolationResultParams: (List[Tree], Tree) = { - def decode(s: String): (String, Boolean) = { - if(s.endsWith("##")) { - val (str, bool) = decode(s.substring(0, s.length-2)) - (str + "#", bool) - } else if(s.endsWith("#")) { - (s.substring(0, s.length-1), true) - } else { - (s, false) - } - } - - /** Fuse adjacent string literals */ - def fuse(l: List[Tree]): List[Tree] = l match { - case Literal(Constant(s1: String)) :: Literal(Constant(s2: String)) :: ss => fuse(Literal(Constant(s1 + s2)) :: ss) - case s :: ss => s :: fuse(ss) - case Nil => Nil - } - - if(rawQueryParts.length == 1) - (List(Literal(Constant(rawQueryParts.head))), Select(SetParameterTree, newTermName("SetUnit"))) - else { - val queryString = new ListBuffer[Tree] - val remaining = new ListBuffer[c.Expr[SetParameter[Unit]]] - paramsList.asInstanceOf[List[c.Expr[Any]]].iterator.zip(rawQueryParts.iterator).foreach { case (param, rawQueryPart) => - val (queryPart, append) = decode(rawQueryPart) - queryString.append(Literal(Constant(queryPart))) - if(append) queryString.append(param.tree) - else { - queryString.append(Literal(Constant("?"))) - remaining += c.Expr[SetParameter[Unit]] { - Apply( - Select( - implicitTree(TypeTree(param.actualType), SetParameterTypeTree), - newTermName("applied") - ), - List(param.tree) - ) - } - } - } - queryString.append(Literal(Constant(rawQueryParts.last))) - val pconv = - if(remaining.isEmpty) Select(SetParameterTree, newTermName("SetUnit")) - else Apply( - Select(SetParameterTree, newTermName("apply")), - List( - Function( - List( - ValDef(Modifiers(Flag.PARAM), newTermName("u"), TypeTree(), EmptyTree), - ValDef(Modifiers(Flag.PARAM), newTermName("pp"), TypeTree(), EmptyTree) - ), - Block( - remaining.toList map ( sp => - Apply( - Select(sp.tree, newTermName("apply")), - List(Ident(newTermName("u")), Ident(newTermName("pp"))) - ) - ), Literal(Constant(())) - ) - ) - ) - ) - (fuse(queryString.result()), pconv) - } - } - - lazy val queryParts: Tree = - Apply(Select(VectorTree, newTermName("apply")), interpolationResultParams._1) - - def staticQueryString: String = interpolationResultParams._1 match { - case Literal(Constant(s: String)) :: Nil => s - case _ => c.abort(c.enclosingPosition, "Only constant strings may be used after '#$' in 'tsql' interpolation") - } - - lazy val pconvTree: Tree = interpolationResultParams._2 - } - - /** - * Helps fetch a ConfigHandler by scanning the cache or else searching - * enclosing class and method definitions for TSQLConfig annotation and - * parsing it - */ - private[jdbc] class MacroConnectionHelper(val c: Context) { - import c.universe._ - - def abort(msg: String) = c.abort(c.enclosingPosition, msg) - - // create a list of strings passed to this interpolation - lazy val rawQueryParts: List[String] = { - //Deconstruct macro application to determine the passed string and the actual parameters - val Apply(Select(Apply(_, List(Apply(_, strArg))), _), paramList) = c.macroApplication - strArg map { - case Literal(Constant(x: String)) => x - case _ => abort("The interpolation contained something other than constants...") - } - } - - /** - * Create a ConfigHandler Expr from a TSQLConfig object, - * most probably retreived from an annotation in the client code - */ - def createConfigHandler(config: TSQLConfig) = Apply( - Select( - New( - createClassTreeFromString("scala.slick.jdbc.TypedStaticQuery.ConfigHandler", newTypeName(_)) - ), nme.CONSTRUCTOR), - List(Literal(Constant(config.dbName))) - ) - - /** - * Create a Tree of the static name of a class - * eg java.lang.String becomes - * Select(Select(Select(Ident(nme.ROOTPKG), "java"), "lang"), "String") - */ - def createClassTreeFromString(classString: String, generator: String => Name): Tree = { - val tokens = classString.split('.').toList - val packages = tokens.dropRight(1) map (newTermName(_)) - val classType = generator(tokens.last) - val firstPackage = Ident(nme.ROOTPKG) - val others = (packages :+ classType) - others.foldLeft[Tree](firstPackage)((prev, elem) => { - Select(prev, elem) - }) - } - - //Shorthand for c.eval - private[this] def eval[T](tree: Tree): T = c.eval(c.Expr[T](c.resetLocalAttrs(tree))) - - /** - * Actually locates a TSQLConfig annotation and - * creates a Tree of ConfigHandler - */ - lazy val configHandlerTree: Tree = { - - //From a list of annotations determine the TSQLConfig annotation - def findAnnotationTree(ann: List[Tree]): Option[Tree] = ann.flatMap { tree => - c.typeCheck(tree, pt = weakTypeOf[TSQLConfig], silent = true) match { - case EmptyTree => None - case _ => { - val Apply(Select(_, _), args) = tree - val realTree = Apply( - Select( - New(createClassTreeFromString("scala.slick.jdbc.TypedStaticQuery.TSQLConfig", newTypeName(_))), - nme.CONSTRUCTOR - ), args map (_.duplicate) - ) - Some(realTree) - } - } - }.headOption - - //Determine the trees - val clasDef = c.enclosingClass.asInstanceOf[MemberDef] - val methDef = Option(c.enclosingMethod).filter(_ != EmptyTree).map(_.asInstanceOf[MemberDef]) - - //Determine the annotations and evaluate corresponding ConfigHandlers - val clasConf = findAnnotationTree(clasDef.mods.annotations) map {t => - createConfigHandler(eval[TSQLConfig](t)) - } - val methConf = methDef.flatMap(md => findAnnotationTree(md.mods.annotations) map {t => - createConfigHandler(eval[TSQLConfig](t)) - }) - - methConf.getOrElse { - clasConf.getOrElse{ - abort("Cannot find suitable config handler for this invocation") - } - } - } - - /** - * Actually locates a TSQLConfig annotation and - * creates an instance of ConfigHandler - */ - lazy val configHandler: ConfigHandler = eval[ConfigHandler](configHandlerTree) - } - - /** - * The class that is used to ensure an uniform database access mechanism - * at the compile-time and run-time by factoring the connection parameters - */ - final class ConfigHandler(databaseName: String) { - - final val configFileName = "reference.conf" - final val configGlobalPrefix = "typedsql." - - private[TypedStaticQuery] final def connectionParameter[T](value: Option[T], - name: String): T = value match { - case Some(x) => x - case None => error(s"Configuration for essential parameter ${name} not found") - } - - def error(msg: String): Nothing = sys.error(msg) - - lazy private[this] val conf = { - val confFile = { - val file = new java.io.File(configFileName) - if (file.isFile() && file.exists()) - file - else - error(s"Configuration file does not exist. Create a file: ${file.getAbsolutePath}") - } - ConfigFactory.parseFile(confFile) -// ConfigFactory.load() - } - - lazy private[this] val databaseConfig: Option[String => String] = try { - Option{ _key => - val c = conf.getConfig(configGlobalPrefix + databaseName) - if (c.hasPath(_key)) c.getString(_key) else null - } - } catch { - case _: ConfigException.Missing => None - } - - //lazy val databaseName:Option[String] = None - lazy val url :Option[String] = databaseConfig.map(_.apply("url")) - lazy val user :Option[String] = databaseConfig.map(_.apply("user")) - lazy val password :Option[String] = databaseConfig.map(_.apply("password")) - lazy val jdbcDriver :Option[String] = databaseConfig.map(_.apply("jdbcDriver")) -// lazy val slickDriver :Option[String] = databaseConfig.map(_.apply("slickDriver")) - - lazy final val connection = JdbcBackend.Database.forURL(connectionParameter(url, "url"), - user = user getOrElse null, - password = password getOrElse null, - driver = connectionParameter(jdbcDriver, "driver") - ) - -// TODO: Fix this! -// lazy final val JdbcDriver = Class.forName(connectionParameter(jdbcDriver, "jdbcDriver")).asInstanceOf[java.sql.Driver] - -// lazy final val SlickDriver = Class.forName(connectionParameter(slickDriver, "slickDriver")).asInstanceOf[scala.slick.driver.JdbcDriver] - - } -} \ No newline at end of file diff --git a/src/main/scala/scala/slick/memory/DistributedBackend.scala b/src/main/scala/scala/slick/memory/DistributedBackend.scala index ea3f58fae5..6f5b297278 100644 --- a/src/main/scala/scala/slick/memory/DistributedBackend.scala +++ b/src/main/scala/scala/slick/memory/DistributedBackend.scala @@ -1,5 +1,6 @@ package scala.slick.memory +import com.typesafe.config.Config import org.reactivestreams.Subscriber import scala.concurrent.{ExecutionContext, Future, blocking} @@ -22,6 +23,9 @@ trait DistributedBackend extends RelationalBackend with Logging { val Database = new DatabaseFactoryDef val backend: DistributedBackend = this + def createDatabase(config: Config, path: String): Database = + throw new SlickException("DistributedBackend cannot be configured with an external config file") + class DatabaseDef(val dbs: Vector[DatabaseComponent#DatabaseDef], val executionContext: ExecutionContext) extends super.DatabaseDef { protected[this] def createDatabaseActionContext[T](_useSameThread: Boolean): Context = new BasicActionContext { val useSameThread = _useSameThread } diff --git a/src/main/scala/scala/slick/memory/HeapBackend.scala b/src/main/scala/scala/slick/memory/HeapBackend.scala index a9098ad394..97a73eb74d 100644 --- a/src/main/scala/scala/slick/memory/HeapBackend.scala +++ b/src/main/scala/scala/slick/memory/HeapBackend.scala @@ -1,6 +1,7 @@ package scala.slick.memory import java.util.concurrent.atomic.AtomicLong +import com.typesafe.config.Config import org.reactivestreams.Subscriber import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} @@ -25,6 +26,8 @@ trait HeapBackend extends RelationalBackend with Logging { val Database = new DatabaseFactoryDef val backend: HeapBackend = this + def createDatabase(config: Config, path: String): Database = Database.apply(ExecutionContext.global) + class DatabaseDef(protected val synchronousExecutionContext: ExecutionContext) extends super.DatabaseDef { protected[this] def createDatabaseActionContext[T](_useSameThread: Boolean): Context = new BasicActionContext { val useSameThread = _useSameThread }