Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[util-test] Easy argument capture for Mockito mocks.
Problem Mockito's ArgumentCaptor enables you to capture the arguments that were passed to stubbed-out methods on mocks, to make assertions on them. However, the API for ArgumentCaptor was built for Java and is kind of clunky compared to what we expect in Scala. An alternative to using ArgumentCaptor is to use Mockito argument matchers in when(myObj.method(...)). However, this means every assertion you want to make has to be encapsulated in a custom matcher subclass, and you can't make assertions spanning more than one argument in each stubbed method invocation. For example, you might want to check that a key-value store wrapper object was called with a key of "foo" and a timeout of 10 seconds, and Mockito argument matchers would not be able to provide this. Another alternative is to use matchers in a verify(myMock, times(N)).someMethod(...) call, but this suffers from the same problem of having to write a specific matcher subclass for any non-standard assertions. Solution The ArgumentCapture trait can be mixed into your test class and has two methods, capturingOne and capturingAll, which capture the invocations of your stubbed methods as a tuple or Seq of tuples. Here's an example: when(myObj.lookup(any[String], any[Duration])).thenReturn(someValue) // The call to myObj.lookup will normally live somewhere inside the // subject under test. assert(myObj.lookup("foo", 10.seconds) === someValue) val (key, timeout) = capturingOne(verify(myObj).lookup _) assert(key === "foo") assert(timeout === 10.seconds) An example which uses capturingAll and doesn't make any assertions about the order of the stubbed method calls (forExactly is defined on the org.scalatest.Inspectors trait): val requests = capturingAll(verify(myObj, times(2)).lookup _) forExactly(1, requests) { (key, timeout) => assert(key === "foo") assert(timeout === 10.seconds) } forExactly(1, requests) { (key, timeout) => assert(key === "bar") assert(timeout === 2.seconds) } =================== #124 RB_ID=587072 TBR=true
- Loading branch information
Zachary Voase
authored and
jenkins
committed
Feb 23, 2015
1 parent
310f6e1
commit cd2c096
Showing
8 changed files
with
1,409 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
usage: | ||
@echo 'No default make target is provided.' | ||
@echo 'Run `make <filename>` to run the code generator on the given file,' | ||
@echo 'or edit the Makefile to change how the generated file is produced.' | ||
@false | ||
|
||
../util-test/src/main/scala/com/twitter/util/testing/ArgumentCapture.scala: util-test/ArgumentCapture.scala.mako | ||
mako-render util-test/ArgumentCapture.scala.mako > $@ | ||
|
||
.PHONY: usage |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
# Code Generators | ||
|
||
This directory contains templates and executables for generating Scala code. We | ||
normally need to do this when defining a function or method that accepts a | ||
variable number of type parameters. For example, we may want to define a `zip` | ||
function for tuples of arbitrary size: | ||
|
||
def zip[A, B](a: Seq[A], b: Seq[B]): Seq[(A, B)] | ||
def zip[A, B, C](a: Seq[A], b: Seq[B], c: Seq[C]): Seq[(A, B, C)] | ||
def zip[A, B, C, D] // and so on... | ||
|
||
There isn't a native way to do this sort of thing in pure Scala (hence the | ||
existence of Tuple1 through Tuple22, Function1...22, etc.). | ||
|
||
## Usage | ||
|
||
A Makefile exists which can build the code-generated files. To (re-)generate a | ||
file: | ||
|
||
make ../util-something/path/to/File.scala | ||
|
||
The Mako Python templating library is used by some targets, and it can be | ||
installed with `pip`: | ||
|
||
pip install -r requirements.txt |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
mako>=1.0.0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
<%! | ||
TYPE_VARS = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ' | ||
MAX_ARITY = 22 | ||
def nested_tuples(n): | ||
""" | ||
>>> nested_tuples(2) | ||
(a, b) | ||
>>> nested_tuples(3) | ||
((a, b), c) | ||
>>> nested_tuples(5) | ||
'((((a, b), c), d), e)' | ||
""" | ||
def pairwise(elems): | ||
if not elems: return '' | ||
elif len(elems) == 1: return elems[0] | ||
else: return '({}, {})'.format(pairwise(elems[:-1]), elems[-1]) | ||
return pairwise(TYPE_VARS.lower()[:n]) | ||
def flat_tuple(n): | ||
return '({})'.format(', '.join(TYPE_VARS.lower()[:n])) | ||
%>\ | ||
package com.twitter.util.testing | ||
|
||
import java.util.{List => JList} | ||
import org.mockito.ArgumentCaptor | ||
import org.mockito.exceptions.Reporter | ||
import scala.collection.JavaConverters._ | ||
import scala.reflect._ | ||
|
||
// This file was generated from codegen/util-test/ArgumentCapture.scala.mako | ||
|
||
trait ArgumentCapture { | ||
/** | ||
* Enables capturingOne to be implemented over capturingAll with the same behavior as ArgumentCaptor.getValue | ||
*/ | ||
private[this] def noArgWasCaptured(): Nothing = { | ||
new Reporter().noArgumentValueWasCaptured() // this always throws an exception | ||
throw new RuntimeException("this should be unreachable, but allows the method to be of type Nothing") | ||
} | ||
|
||
/** | ||
* Capture all the invocations from a verify(mock).method(arg) call. | ||
* | ||
* Example: | ||
* val requests = capturingAll(verify(myAPIEndpoint, times(4)).authenticate _) | ||
* requests.length must equal (4) | ||
*/ | ||
def capturingAll[T: ClassTag](f: T => _): Seq[T] = { | ||
val argCaptor = ArgumentCaptor.forClass(classTag[T].runtimeClass.asInstanceOf[Class[T]]) | ||
f(argCaptor.capture()) | ||
argCaptor.getAllValues.asScala.toSeq | ||
} | ||
|
||
/** | ||
* Capture an argument from a verify(mock).method(arg) call. | ||
* | ||
* Example: | ||
* val request = capturingOne(verify(myAPIEndpoint).authenticate _) | ||
* request.userId must equal (123L) | ||
* request.password must equal ("reallySecurePassword") | ||
*/ | ||
def capturingOne[T: ClassTag](f: T => _): T = | ||
capturingAll[T](f).lastOption.getOrElse(noArgWasCaptured())\ | ||
|
||
% for i in xrange(2, MAX_ARITY + 1): | ||
<% | ||
types = ', '.join(TYPE_VARS[:i]) | ||
types_with_class_tags = ', '.join('{}: ClassTag'.format(t) for t in TYPE_VARS[:i]) | ||
iterable_args = ', '.join('arg{}: Iterable[{}]'.format(j, type) for j, type in enumerate(TYPE_VARS[:i])) | ||
%> | ||
/** Zip ${i} iterables together into a Seq of ${i}-tuples. */ | ||
private[this] def zipN[${types}](${iterable_args}): Seq[(${types})] = { | ||
% if i == 2: | ||
arg0.zip(arg1).toSeq | ||
% else: | ||
arg0 | ||
% for j in xrange(1, i): | ||
.zip(arg${j}) | ||
% endfor | ||
.map({ case ${nested_tuples(i)} => ${flat_tuple(i)} }) | ||
.toSeq | ||
% endif | ||
} | ||
|
||
/** Capture all invocations of a mocked ${i}-ary method */ | ||
def capturingAll[${types_with_class_tags}](func: (${types}) => _): Seq[(${types})] = { | ||
% for type in TYPE_VARS[:i]: | ||
val argCaptor${type} = ArgumentCaptor.forClass(classTag[${type}].runtimeClass.asInstanceOf[Class[${type}]]) | ||
% endfor | ||
func(${', '.join("argCaptor{}.capture()".format(type) for type in TYPE_VARS[:i])}) | ||
% for type in TYPE_VARS[:i]: | ||
val args${type} = argCaptor${type}.getAllValues.asScala | ||
% endfor | ||
zipN(${', '.join("args{}".format(type) for type in TYPE_VARS[:i])}) | ||
} | ||
|
||
/** Capture one invocation of a mocked ${i}-ary method */ | ||
def capturingOne[${types_with_class_tags}](func: (${types}) => _): (${types}) = | ||
capturingAll[${types}](func).lastOption.getOrElse(noArgWasCaptured()) | ||
% endfor | ||
|
||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.