Skip to content

Commit

Permalink
[util-test] Easy argument capture for Mockito mocks.
Browse files Browse the repository at this point in the history
Problem

Mockito's ArgumentCaptor enables you to capture the arguments that were
passed to stubbed-out methods on mocks, to make assertions on them.
However, the API for ArgumentCaptor was built for Java and is kind of
clunky compared to what we expect in Scala.

An alternative to using ArgumentCaptor is to use Mockito argument
matchers in when(myObj.method(...)). However, this means every
assertion you want to make has to be encapsulated in a custom matcher
subclass, and you can't make assertions spanning more than one argument
in each stubbed method invocation. For example, you might want to check
that a key-value store wrapper object was called with a key of "foo" and
a timeout of 10 seconds, and Mockito argument matchers would not be able
to provide this.

Another alternative is to use matchers in a
verify(myMock, times(N)).someMethod(...) call, but this suffers from
the same problem of having to write a specific matcher subclass for any
non-standard assertions.

Solution

The ArgumentCapture trait can be mixed into your test class and has two
methods, capturingOne and capturingAll, which capture the
invocations of your stubbed methods as a tuple or Seq of tuples. Here's
an example:

when(myObj.lookup(any[String], any[Duration])).thenReturn(someValue)

// The call to myObj.lookup will normally live somewhere inside the
// subject under test.
assert(myObj.lookup("foo", 10.seconds) === someValue)

val (key, timeout) = capturingOne(verify(myObj).lookup _)
assert(key === "foo")
assert(timeout === 10.seconds)
An example which uses capturingAll and doesn't make any assertions
about the order of the stubbed method calls (forExactly is defined on
the org.scalatest.Inspectors trait):

val requests = capturingAll(verify(myObj, times(2)).lookup _)

forExactly(1, requests) { (key, timeout) =>
  assert(key === "foo")
  assert(timeout === 10.seconds)
}

forExactly(1, requests) { (key, timeout) =>
  assert(key === "bar")
  assert(timeout === 2.seconds)
}

===================

#124

RB_ID=587072
TBR=true
  • Loading branch information
Zachary Voase authored and jenkins committed Feb 23, 2015
1 parent 310f6e1 commit cd2c096
Show file tree
Hide file tree
Showing 8 changed files with 1,409 additions and 1 deletion.
10 changes: 10 additions & 0 deletions codegen/Makefile
@@ -0,0 +1,10 @@
usage:
@echo 'No default make target is provided.'
@echo 'Run `make <filename>` to run the code generator on the given file,'
@echo 'or edit the Makefile to change how the generated file is produced.'
@false

../util-test/src/main/scala/com/twitter/util/testing/ArgumentCapture.scala: util-test/ArgumentCapture.scala.mako
mako-render util-test/ArgumentCapture.scala.mako > $@

.PHONY: usage
25 changes: 25 additions & 0 deletions codegen/README.markdown
@@ -0,0 +1,25 @@
# Code Generators

This directory contains templates and executables for generating Scala code. We
normally need to do this when defining a function or method that accepts a
variable number of type parameters. For example, we may want to define a `zip`
function for tuples of arbitrary size:

def zip[A, B](a: Seq[A], b: Seq[B]): Seq[(A, B)]
def zip[A, B, C](a: Seq[A], b: Seq[B], c: Seq[C]): Seq[(A, B, C)]
def zip[A, B, C, D] // and so on...

There isn't a native way to do this sort of thing in pure Scala (hence the
existence of Tuple1 through Tuple22, Function1...22, etc.).

## Usage

A Makefile exists which can build the code-generated files. To (re-)generate a
file:

make ../util-something/path/to/File.scala

The Mako Python templating library is used by some targets, and it can be
installed with `pip`:

pip install -r requirements.txt
1 change: 1 addition & 0 deletions codegen/requirements.txt
@@ -0,0 +1 @@
mako>=1.0.0
103 changes: 103 additions & 0 deletions codegen/util-test/ArgumentCapture.scala.mako
@@ -0,0 +1,103 @@
<%!
TYPE_VARS = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
MAX_ARITY = 22
def nested_tuples(n):
"""
>>> nested_tuples(2)
(a, b)
>>> nested_tuples(3)
((a, b), c)
>>> nested_tuples(5)
'((((a, b), c), d), e)'
"""
def pairwise(elems):
if not elems: return ''
elif len(elems) == 1: return elems[0]
else: return '({}, {})'.format(pairwise(elems[:-1]), elems[-1])
return pairwise(TYPE_VARS.lower()[:n])
def flat_tuple(n):
return '({})'.format(', '.join(TYPE_VARS.lower()[:n]))
%>\
package com.twitter.util.testing

import java.util.{List => JList}
import org.mockito.ArgumentCaptor
import org.mockito.exceptions.Reporter
import scala.collection.JavaConverters._
import scala.reflect._

// This file was generated from codegen/util-test/ArgumentCapture.scala.mako

trait ArgumentCapture {
/**
* Enables capturingOne to be implemented over capturingAll with the same behavior as ArgumentCaptor.getValue
*/
private[this] def noArgWasCaptured(): Nothing = {
new Reporter().noArgumentValueWasCaptured() // this always throws an exception
throw new RuntimeException("this should be unreachable, but allows the method to be of type Nothing")
}

/**
* Capture all the invocations from a verify(mock).method(arg) call.
*
* Example:
* val requests = capturingAll(verify(myAPIEndpoint, times(4)).authenticate _)
* requests.length must equal (4)
*/
def capturingAll[T: ClassTag](f: T => _): Seq[T] = {
val argCaptor = ArgumentCaptor.forClass(classTag[T].runtimeClass.asInstanceOf[Class[T]])
f(argCaptor.capture())
argCaptor.getAllValues.asScala.toSeq
}

/**
* Capture an argument from a verify(mock).method(arg) call.
*
* Example:
* val request = capturingOne(verify(myAPIEndpoint).authenticate _)
* request.userId must equal (123L)
* request.password must equal ("reallySecurePassword")
*/
def capturingOne[T: ClassTag](f: T => _): T =
capturingAll[T](f).lastOption.getOrElse(noArgWasCaptured())\

% for i in xrange(2, MAX_ARITY + 1):
<%
types = ', '.join(TYPE_VARS[:i])
types_with_class_tags = ', '.join('{}: ClassTag'.format(t) for t in TYPE_VARS[:i])
iterable_args = ', '.join('arg{}: Iterable[{}]'.format(j, type) for j, type in enumerate(TYPE_VARS[:i]))
%>
/** Zip ${i} iterables together into a Seq of ${i}-tuples. */
private[this] def zipN[${types}](${iterable_args}): Seq[(${types})] = {
% if i == 2:
arg0.zip(arg1).toSeq
% else:
arg0
% for j in xrange(1, i):
.zip(arg${j})
% endfor
.map({ case ${nested_tuples(i)} => ${flat_tuple(i)} })
.toSeq
% endif
}

/** Capture all invocations of a mocked ${i}-ary method */
def capturingAll[${types_with_class_tags}](func: (${types}) => _): Seq[(${types})] = {
% for type in TYPE_VARS[:i]:
val argCaptor${type} = ArgumentCaptor.forClass(classTag[${type}].runtimeClass.asInstanceOf[Class[${type}]])
% endfor
func(${', '.join("argCaptor{}.capture()".format(type) for type in TYPE_VARS[:i])})
% for type in TYPE_VARS[:i]:
val args${type} = argCaptor${type}.getAllValues.asScala
% endfor
zipN(${', '.join("args{}".format(type) for type in TYPE_VARS[:i])})
}

/** Capture one invocation of a mocked ${i}-ary method */
def capturingOne[${types_with_class_tags}](func: (${types}) => _): (${types}) =
capturingAll[${types}](func).lastOption.getOrElse(noArgWasCaptured())
% endfor

}
5 changes: 4 additions & 1 deletion project/Build.scala
Expand Up @@ -300,7 +300,10 @@ object Util extends Build {
sharedSettings
).settings(
name := "util-test",
libraryDependencies += "org.scalatest" %% "scalatest" % "2.2.2"
libraryDependencies ++= Seq(
"org.scalatest" %% "scalatest" % "2.2.2",
"org.mockito" % "mockito-all" % "1.8.5"
)
).dependsOn(utilCore, utilLogging)


Expand Down
1 change: 1 addition & 0 deletions util-test/src/main/scala/BUILD
Expand Up @@ -5,6 +5,7 @@ scala_library(name='scala',
repo = artifactory,
),
dependencies=[
'3rdparty/jvm/org/mockito:mockito-all',
'3rdparty/jvm/org/scalatest',
'util/util-logging/src/main/scala',
'util/util-core/src/main/scala',
Expand Down

0 comments on commit cd2c096

Please sign in to comment.