Skip to content
This repository has been archived by the owner on Jun 14, 2020. It is now read-only.

Commit

Permalink
allow setup, cleanup functions to access ClassLoader used for testing
Browse files Browse the repository at this point in the history
implements sbt#118
  • Loading branch information
harrah committed Oct 9, 2010
1 parent 32bfb21 commit 8f1899d
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 12 deletions.
10 changes: 7 additions & 3 deletions sbt/src/main/scala/sbt/ScalaProject.scala
Expand Up @@ -69,8 +69,12 @@ trait ScalaProject extends SimpleScalaProject with FileTasks with MultiTaskProje
trait PackageOption extends ActionOption
trait TestOption extends ActionOption

case class TestSetup(setup: () => Option[String]) extends TestOption
case class TestCleanup(cleanup: () => Option[String]) extends TestOption
case class TestSetup(setup: ClassLoader => Option[String]) extends TestOption {
def this(setup: () => Option[String]) = this(_ => setup())
}
case class TestCleanup(cleanup: ClassLoader => Option[String]) extends TestOption {
def this(setup: () => Option[String]) = this(_ => setup())
}
case class ExcludeTests(tests: Iterable[String]) extends TestOption
case class TestListeners(listeners: Iterable[TestReportListener]) extends TestOption
case class TestFilter(filterTest: String => Boolean) extends TestOption
Expand Down Expand Up @@ -275,7 +279,7 @@ trait ScalaProject extends SimpleScalaProject with FileTasks with MultiTaskProje

val testFilters = new ListBuffer[String => Boolean]
val excludeTestsSet = new HashSet[String]
val setup, cleanup = new ListBuffer[() => Option[String]]
val setup, cleanup = new ListBuffer[ClassLoader => Option[String]]
val testListeners = new ListBuffer[TestReportListener]
val testArgsByFramework = Map[TestFramework, ListBuffer[String]]()
def frameworkArgs(framework: TestFramework): ListBuffer[String] =
Expand Down
18 changes: 9 additions & 9 deletions sbt/src/main/scala/sbt/TestFramework.scala
Expand Up @@ -126,16 +126,16 @@ object TestFramework
log: Logger,
listeners: Seq[TestReportListener],
endErrorsEnabled: Boolean,
setup: Iterable[() => Option[String]],
cleanup: Iterable[() => Option[String]],
setup: Iterable[ClassLoader => Option[String]],
cleanup: Iterable[ClassLoader => Option[String]],
testArgsByFramework: Map[TestFramework, Seq[String]]):
(Iterable[NamedTestTask], Iterable[NamedTestTask], Iterable[NamedTestTask]) =
{
val (loader, tempDir) = createTestLoader(classpath, scalaInstance)
val arguments = immutable.Map() ++
( for(framework <- frameworks; created <- framework.create(loader, log)) yield
(created, testArgsByFramework.getOrElse(framework, Nil)) )
val cleanTmp = () => FileUtilities.clean(tempDir, log)
val cleanTmp = (_: ClassLoader) => FileUtilities.clean(tempDir, log)

val mappedTests = testMap(arguments.keys.toList, tests, arguments)
if(mappedTests.isEmpty)
Expand All @@ -162,12 +162,12 @@ object TestFramework
assignTests()
(immutable.Map() ++ map) transform { (framework, tests) => (tests, args(framework)) }
}
private def createTasks(work: Iterable[() => Option[String]], baseName: String) =
work.toList.zipWithIndex.map{ case (work, index) => new NamedTestTask(baseName + " " + (index+1), work()) }
private def createTasks[T](work: Iterable[T => Option[String]], baseName: String, input: T) =
work.toList.zipWithIndex.map{ case (work, index) => new NamedTestTask(baseName + " " + (index+1), work(input)) }

private def createTestTasks(loader: ClassLoader, tests: Map[Framework, (Set[TestDefinition], Seq[String])], log: Logger,
listeners: Seq[TestReportListener], endErrorsEnabled: Boolean, setup: Iterable[() => Option[String]],
cleanup: Iterable[() => Option[String]]) =
listeners: Seq[TestReportListener], endErrorsEnabled: Boolean, setup: Iterable[ClassLoader => Option[String]],
cleanup: Iterable[ClassLoader => Option[String]]) =
{
val testsListeners = listeners.filter(_.isInstanceOf[TestsListener]).map(_.asInstanceOf[TestsListener])
def foreachListenerSafe(f: TestsListener => Unit): Unit = safeForeach(testsListeners, log)(f)
Expand All @@ -179,7 +179,7 @@ object TestFramework
def apply() = synchronized { value }
def update(v: Result.Value): Unit = synchronized { if(value != Error) value = v }
}
val startTask = new NamedTestTask(TestStartName, {foreachListenerSafe(_.doInit); None}) :: createTasks(setup, "Test setup")
val startTask = new NamedTestTask(TestStartName, {foreachListenerSafe(_.doInit); None}) :: createTasks(setup, "Test setup", loader)
val testTasks =
tests flatMap { case (framework, (testDefinitions, testArgs)) =>

Expand Down Expand Up @@ -220,7 +220,7 @@ object TestFramework
}
}
}
val endTask = new NamedTestTask(TestFinishName, end() ) :: createTasks(cleanup, "Test cleanup")
val endTask = new NamedTestTask(TestFinishName, end() ) :: createTasks(cleanup, "Test cleanup", loader)
(startTask, testTasks, endTask)
}
def createTestLoader(classpath: Iterable[Path], scalaInstance: ScalaInstance): (ClassLoader, Path) =
Expand Down

0 comments on commit 8f1899d

Please sign in to comment.