Skip to content

Commit

Permalink
Get unmocked singleton objects working
Browse files Browse the repository at this point in the history
  • Loading branch information
paulbutcher committed Oct 20, 2011
1 parent 55a3115 commit d6e317d
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 13 deletions.
33 changes: 27 additions & 6 deletions compiler_plugin/src/main/scala/GenerateMocks.scala
Expand Up @@ -112,15 +112,22 @@ class GenerateMocks(plugin: BorachioPlugin, val global: Global) extends PluginCo
def mockFactory =
"package com.borachio.generated\n\n"+
"trait GeneratedMockFactory extends com.borachio.GeneratedMockFactoryBase { self: com.borachio.MockFactoryBase =>\n"+
(mocks map { case (target, mock) => mockFactoryEntry(target, mock) }).mkString("\n") +"\n\n"+
(mockObjects map { case (target, mock) => mockObjectEntry(target, mock) }).mkString("\n") +"\n\n"+
toMockMethods +"\n\n"+
mockObjectMethods +"\n"+
"}"

def toMockMethods =
(mocks map { case (target, mock) => mockFactoryEntry(target, mock) }).mkString("\n")

def mockObjectMethods =
(mockObjects map { case (target, mock) => mockObjectEntry(target, mock) }).mkString("\n")

def mockFactoryEntry(target: String, mock: String) =
" implicit def toMock(m: "+ target +") = m.asInstanceOf["+ mock +"]"

def mockObjectEntry(target: String, mock: String) = " def mockObject(x: "+ target +".type) = x.asInstanceOf["+ mock +"]"

def mockObjectEntry(target: String, mock: String) =
" def mockObject(x: "+ target +".type) = objectToMock["+ mock +"](x)"

trait Context {
val topLevel: Boolean
val fullMockTraitOrClassName: String
Expand Down Expand Up @@ -191,7 +198,7 @@ class GenerateMocks(plugin: BorachioPlugin, val global: Global) extends PluginCo
def getMock: String =
mockedTypeDeclaration +" {\n\n"+
mockClassEntries +"\n\n"+
forwardTo +"\n\n"+
forwarding +"\n\n"+
indent((nestedMocks map ( _.getMock )).mkString("\n")) +"\n\n"+
"}"

Expand Down Expand Up @@ -239,7 +246,9 @@ class GenerateMocks(plugin: BorachioPlugin, val global: Global) extends PluginCo
" method.invoke(classLoader).asInstanceOf[com.borachio.MockFactoryBase]\n"+
" }"

lazy val forwardTo = " private var forwardTo: AnyRef = _\n\n"+ forwarders
lazy val forwarding = forwardTo +"\n\n"+ forwarders

def forwardTo = " private var forwardTo: AnyRef = _"

lazy val forwarders = (methodsToMock map forwardMethod _).mkString("\n")

Expand Down Expand Up @@ -597,6 +606,18 @@ class GenerateMocks(plugin: BorachioPlugin, val global: Global) extends PluginCo

override lazy val fullClassName = qualifiedClassName

override def forwardTo = super.forwardTo +"\n\n"+ resetForwarding

lazy val resetForwarding =
" resetForwarding\n\n"+
" def resetForwarding() {\n"+
" val clazz = com.borachio.ReflectionUtilities.getUnmockedClass(getClass, \""+ fullClassName +"$\")\n"+
" forwardTo = clazz.getField(\"MODULE$\").get(null)\n"+
" }\n\n"+
" def enableForwarding() {\n"+
" forwardTo = null\n"+
" }"

override def getMockTraitOrClassName = "Mock$$"+ className

override def recordMapping(mapping: (String, String)) {
Expand Down
2 changes: 1 addition & 1 deletion compiler_plugin_tests/src/test/scala/PluginTest.scala
Expand Up @@ -213,7 +213,7 @@ class PluginTest extends FunSuite with MockFactory with GeneratedMockFactory wit
}

//! TODO
ignore("unmocked simple object") {
test("unmocked simple object") {
expect("hello") { SimpleObject.sayHello }
}

Expand Down
29 changes: 23 additions & 6 deletions core/src/main/scala/GeneratedMockFactoryBase.scala
Expand Up @@ -20,19 +20,36 @@

package com.borachio

import scala.collection.mutable.ListBuffer

trait GeneratedMockFactoryBase { self: MockFactoryBase =>

protected def mock[T: ClassManifest] = {
val constructor = classToCreate[T].getConstructor(classOf[MockConstructorDummy])
constructor.newInstance(new MockConstructorDummy).asInstanceOf[T]
}

protected def objectToMock[M](x: AnyRef) = {
x.getClass.getMethod("enableForwarding").invoke(x)
mockedObjects += x
x.asInstanceOf[M]
}

private[borachio] def classToCreate[T: ClassManifest] = {
protected override def resetMocks() {
mockedObjects foreach { m =>
m.getClass.getMethod("resetForwarding").invoke(m)
}
mockedObjects.clear
}

private def classToCreate[T: ClassManifest] = {
val erasure = classManifest[T].erasure
val clazz = Class.forName(erasure.getName)
if (clazz.isInterface)
Class.forName(erasure.getPackage.getName +".Mock$"+ erasure.getSimpleName)
else
clazz
}

protected def mock[T: ClassManifest] = {
val constructor = classToCreate[T].getConstructor(classOf[MockConstructorDummy])
constructor.newInstance(new MockConstructorDummy).asInstanceOf[T]
}

private val mockedObjects = new ListBuffer[AnyRef]
}
2 changes: 2 additions & 0 deletions core/src/main/scala/MockFactoryBase.scala
Expand Up @@ -40,6 +40,8 @@ trait MockFactoryBase {
expectationContext = null
}

protected def resetMocks() {}

protected def inAnyOrder(what: => Unit) {
inContext(new UnorderedExpectations)(what)
}
Expand Down
3 changes: 3 additions & 0 deletions frameworks/scalatest/src/main/scala/MockFactory.scala
Expand Up @@ -85,6 +85,9 @@ trait MockFactory extends AbstractSuite with MockFactoryBase { this: Suite =>
case None => throw laterException
}
}
finally {
resetMocks
}
}
}

Expand Down

0 comments on commit d6e317d

Please sign in to comment.