-
-
Notifications
You must be signed in to change notification settings - Fork 609
/
TestCodeGenerator.scala
137 lines (121 loc) · 4.64 KB
/
TestCodeGenerator.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
package com.typesafe.slick.testkit.util
import scala.concurrent.Await
import scala.concurrent.duration.Duration
import scala.concurrent.ExecutionContext.Implicits.global
import scala.io.{Codec, Source}
import java.util.concurrent.ExecutionException
import slick.codegen.{OutputHelpers, SourceCodeGenerator}
import slick.dbio._
import slick.model.Model
import org.junit.Test
/** Generates Slick code-generator output plus an `AllTests` index for a set of test databases.
  *
  * Implementors supply the target package, the configurations to run and the default test body.
  * `main` then writes one source file per usable configuration and an `AllTests.scala` file
  * listing the generated object names.
  */
trait TestCodeGenerator {
  /** Package that all generated sources are placed in. */
  def packageName: String

  /** Test body embedded into the generated object for `c` (see `Config.testCode`). */
  def defaultTestCode(c: Config): String

  /** The configurations for which sources are generated. */
  def configurations: Seq[Config]

  /** Builds the qualified reference `<StandardTestDBs class name without '$'>.<tdbName>`. */
  def computeFullTdbName(tdbName: String): String =
    StandardTestDBs.getClass.getName.replaceAll("\\$", "") + "." + tdbName

  /** Entry point: generates every configuration into the folder `args(0)` and writes `AllTests.scala`.
    *
    * Any failure is printed to stderr and the JVM exits with status 1.
    */
  def main(args: Array[String]): Unit = try {
    val outputDir = args(0)
    val clns = configurations.flatMap(_.generate(outputDir).toSeq)
    // Only `writeStringToFile` is needed here, so every abstract member of OutputHelpers is stubbed.
    new OutputHelpers {
      def indent(code: String): String = code
      def code: String = ""
      def codePerTable: Map[String, String] = Map()
      def foreignKeysPerTable: Map[String, List[String]] = Map()
      def codeForContainer: String = ""
    }.writeStringToFile(
      s"""
         |package $packageName
         |object AllTests extends com.typesafe.slick.testkit.util.TestCodeRunner.AllTests {
         | val clns = Seq(${clns.map("\"" + _ + "\"").mkString(", ")})
         |}
         |""".stripMargin, outputDir, packageName, "AllTests.scala"
    )
  } catch { case ex: Throwable =>
    // Deliberately broad: this is the process edge, and every failure must end in exit status 1
    // (presumably so the invoking build notices it).
    ex.printStackTrace(System.err)
    System.exit(1)
  }

  /** One code-generation run against a single test database.
    *
    * @param objectName  name of the generated object (and of its `.scala` file)
    * @param tdb         test database the model is read from
    * @param tdbName     member name passed to `computeFullTdbName`
    * @param initScripts classpath resources (resolved against this instance's class) holding the
    *                    SQL statements to run before the model is created
    */
  class Config(val objectName: String, val tdb: JdbcTestDB, tdbName: String, initScripts: Seq[String]) { self =>
    /** When true, each statement is collapsed onto one line and its trailing `;` is removed. */
    def useSingleLineStatements: Boolean = false

    /** Class name of the database profile with all `$` characters removed. */
    def slickProfile: String = tdb.profile.getClass.getName.replaceAll("\\$", "")

    /** Qualified reference to the test database, emitted into the generated code. */
    def fullTdbName: String = computeFullTdbName(tdbName)

    /** Runs the init scripts, creates the model and writes the generated source into `dir`.
      *
      * @return `Some("<packageName>.<objectName>")` when the database is enabled or is an
      *         `InternalJdbcTestDB`, otherwise `None` (nothing is generated)
      */
    def generate(dir: String): Option[String] =
      if (tdb.isEnabled || tdb.isInstanceOf[InternalJdbcTestDB]) {
        tdb.cleanUpBefore()
        try {
          val statements = initScripts.flatMap(readStatements)
          val noop: DBIO[Any] = DBIO.successful(())
          val init: DBIO[Any] = {
            // Kept local: only the `sqlu` interpolator is needed from the profile API.
            import tdb.profile.api._
            statements.foldLeft(noop)((acc, stmt) => acc >> sqlu"#$stmt")
          }
          val db = tdb.createDB()
          try {
            val m = Await.result(db.run((init >> generator).withPinnedSession), Duration.Inf)
            m.writeToFile(profile=slickProfile, folder=dir, pkg=packageName, objectName, fileName=objectName+".scala" )
          } finally db.close()
        }
        finally tdb.cleanUpAfter()
        Some(s"$packageName.$objectName")
      } else None

    /** Splits one init script into statements.
      *
      * A statement ends at a line whose trimmed text ends with `;`. Lines left over after the
      * last terminator form a final statement. The pending buffer is local to each script, so a
      * trailing fragment can no longer leak into the next script.
      */
    private def readStatements(initScript: String): Vector[String] = {
      val url = Option(self.getClass.getResource(initScript)).getOrElse(
        throw new IllegalArgumentException(s"Init script not found on classpath: $initScript"))
      val source = Source.fromURL(url)(Codec.UTF8)
      try {
        val statements = Vector.newBuilder[String]
        var pending = Vector.empty[String]
        // `foreach` consumes all lines eagerly, so closing the source afterwards is safe.
        source.getLines().foreach { line =>
          pending :+= line
          if (line.trim.endsWith(";")) {
            statements += toStatement(pending, terminated = true)
            pending = Vector.empty
          }
        }
        if (pending.nonEmpty) statements += toStatement(pending, terminated = false)
        statements.result()
      } finally source.close()
    }

    /** Joins the lines of one statement, applying the single-line rewrite when it is enabled. */
    private def toStatement(lines: Seq[String], terminated: Boolean): String = {
      val text = lines.mkString("\n")
      if (!useSingleLineStatements) text
      else {
        val flat = text.replace("\r", "").replace('\n', ' ')
        // The terminator is the last ';' (only whitespace may follow it), so cutting there also
        // handles lines with trailing blanks after the semicolon.
        if (terminated) flat.substring(0, flat.lastIndexOf(';')) else flat
      }
    }

    /** Action that reads the database model and wraps it in a `MyGen`. */
    def generator: DBIO[SourceCodeGenerator] =
      tdb.profile.createModel(ignoreInvalidDefaults=false).map(new MyGen(_))

    /** Test body embedded into the generated object; defaults to `defaultTestCode(this)`. */
    def testCode: String = defaultTestCode(this)

    /** Source generator that adds `tdb` and `test` members to the generated code. */
    class MyGen(model:Model) extends SourceCodeGenerator(model) {
      // Ignoring the last three characters (presumably the default entity suffix - TODO confirm),
      // a name ending in 's' loses that 's' together with the three-character suffix.
      // `endsWith` is used instead of `.last` so names of three characters or fewer do not throw.
      override def entityName = sqlName => {
        val baseName = super.entityName(sqlName)
        if (baseName.dropRight(3).endsWith("s")) baseName.dropRight(4)
        else baseName
      }
      override def parentType = Some("com.typesafe.slick.testkit.util.TestCodeRunner.TestCase")
      // Prepends the test database reference and the test body to the regular generated code.
      override def code = {
        s"""
           |lazy val tdb = $fullTdbName
           |def test = {
           | import org.junit.Assert._
           | import scala.concurrent.ExecutionContext.Implicits.global
           | $testCode
           |}
           |""".stripMargin + super.code
      }
    }
  }
}
/** JUnit driver that runs every generated test object listed in `tests`.
  *
  * @param tests supplies the fully qualified names of the generated test objects
  */
class TestCodeRunner(tests: TestCodeRunner.AllTests) {
  /** Runs the test of the Scala object named `cln`.
    *
    * The object is looked up reflectively through its `MODULE$` field and must be a
    * `TestCodeRunner.TestCase`. When its test database is disabled, only a message is printed.
    */
  def run(cln: String): Unit = {
    val t = Class.forName(cln+"$").getField("MODULE$").get(null).asInstanceOf[TestCodeRunner.TestCase]
    val tdb = t.tdb
    println(s"Running test $cln on ${tdb.confName}")
    if(tdb.isEnabled) {
      tdb.cleanUpBefore()
      try {
        val a = t.test
        val db = tdb.createDB()
        try Await.result(db.run(a.withPinnedSession), Duration.Inf)
        // Rethrow the underlying failure. Fall back to the wrapper itself when it has no cause,
        // because `throw null` would replace the real error with a NullPointerException.
        catch { case e: ExecutionException => throw Option(e.getCause).getOrElse(e) }
        finally db.close()
      } finally tdb.cleanUpAfter()
    } else println("- Test database is disabled")
  }

  /** Single JUnit test that runs all listed test objects in order. */
  @Test def allTests: Unit = tests.clns.foreach(run)
}
/** Contracts shared between the generated test sources and the `TestCodeRunner` class. */
object TestCodeRunner {

  /** Index of generated test objects. */
  trait AllTests {
    /** Fully qualified names of the generated test objects. */
    def clns: Seq[String]
  }

  /** A single generated test object. */
  trait TestCase {
    /** Test database the test runs against. */
    def tdb: JdbcTestDB

    /** The database action that makes up the test. */
    def test: slick.dbio.DBIO[Any]
  }
}