
Tester

This class provides a generic unit testing framework. It is already used in the nn package to verify the correctness of its classes.

The framework is generally used as follows.

local mytest = torch.TestSuite()

local tester = torch.Tester()

function mytest.testA()
   local a = torch.Tensor{1, 2, 3}
   local b = torch.Tensor{1, 2, 4}
   tester:eq(a, b, "a and b should be equal")
end

function mytest.testB()
   local a = {2, torch.Tensor{1, 2, 2}}
   local b = {2, torch.Tensor{1, 2, 2.001}}
   tester:eq(a, b, 0.01, "a and b should be approximately equal")
end

function mytest.testC()
   local function myfunc()
      return "hello " .. world
   end
   tester:assertNoError(myfunc, "myfunc shouldn't give an error")
end

tester:add(mytest)
tester:run()

Running this code will report two test failures (and one test success). Generally, it is better to put a single test case in each test function, unless there are several closely related test cases. The error report includes the failure message and the line number where the failure occurred.

Running 3 tests
1/3 testB ............................................................... [PASS]
2/3 testA ............................................................... [FAIL]
3/3 testC ............................................................... [FAIL]
Completed 3 asserts in 3 tests with 2 failures and 0 errors
--------------------------------------------------------------------------------
testA
a and b should be equal
TensorEQ(==) violation: max diff=1, tolerance=0
stack traceback:
        ./test.lua:8: in function <./test.lua:5>

--------------------------------------------------------------------------------
testC
myfunc shouldn't give an error
ERROR violation: err=./test.lua:19: attempt to concatenate global 'world' (a nil value)
stack traceback:
        ./test.lua:21: in function <./test.lua:17>

--------------------------------------------------------------------------------
torch/torch/Tester.lua:383: An error was found while running tests!
stack traceback:
        [C]: in function 'assert'
        torch/torch/Tester.lua:383: in function 'run'
        ./test.lua:25: in main chunk

Historically, Tester has supported a variety of equality checks (asserteq, assertalmosteq, assertTensorEq, assertTableEq, and their negations). In general, however, you should just use eq (or its negation, ne). These functions perform deep checks on many object types, including recursive tables and tensors, and they accept a tolerance parameter for comparing numerical values (including tensor elements).

Many of the tester functions accept both an optional tolerance parameter and a message to display if the test case fails. For both convenience and backwards compatibility, these arguments can be supplied in either order.
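The following sketch illustrates the either-order rule. It assumes torch7 is installed, and the suite, test name and values are hypothetical. Both calls compare the same tensors with a tolerance of 1e-3. Per the text above, both should pass, and each returns whether it passed.

```lua
require 'torch'

local tester = torch.Tester()
local suite = torch.TestSuite()
local results = {}

function suite.testArgumentOrder()
   local a = torch.Tensor{1, 2, 3}
   local b = torch.Tensor{1, 2, 3.0005}
   -- Tolerance first, then message.
   results.tolFirst = tester:eq(a, b, 1e-3, "a and b should be close")
   -- Message first, then tolerance (accepted for backwards compatibility).
   results.msgFirst = tester:eq(a, b, "a and b should be close", 1e-3)
end

tester:add(suite)
tester:run()
```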

torch.Tester()

Returns a new instance of the torch.Tester class.

add(f, 'name')

Add f, either a test function or a table of test functions, to the tester.

If f is a function, then name should be unique. Two values of name are special: if it is _setUp or _tearDown, the function is called before or after every test, respectively, with the name of the test passed as a parameter.

If f is a table, then name should be nil, and each test's name is taken from its key in the table. Using a TestSuite for tables of tests is recommended.

Returns the torch.Tester instance.
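The _setUp and _tearDown hooks can be registered by passing a function to add under those special names. The following sketch assumes torch7 is installed. The test names and the log table are hypothetical. Because the order in which tests run is not guaranteed, the sketch only relies on each test being bracketed by one set-up call and one tear-down call that receive the same name.

```lua
require 'torch'

local tester = torch.Tester()
local suite = torch.TestSuite()
local log = {}

function suite.testOne() tester:assert(true, "always passes") end
function suite.testTwo() tester:assert(true, "always passes") end

tester:add(suite)
-- Special names: called before / after every test with that test's name.
tester:add(function(name) log[#log + 1] = "setUp " .. name end, '_setUp')
tester:add(function(name) log[#log + 1] = "tearDown " .. name end, '_tearDown')
tester:run()

for _, line in ipairs(log) do print(line) end
```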

run(testNames)

Run tests that have been added with add(f, 'name'). While running, it reports progress; at the end, it prints a summary of all errors.

If testNames is given (a single name or a list of names), all tests whose names match (using string.match) are run; otherwise, all tests are run.

tester:run() -- runs all tests
tester:run("test1") -- runs the tests whose names match "test1"
tester:run({"test2", "test3"}) -- runs the tests whose names match "test2" or "test3"
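Test names are matched with string.match, so each name passed to run acts as an unanchored Lua pattern rather than an exact name. For example, "test1" would also select tests called "test10" or "mytest1". The following plain-Lua sketch uses hypothetical test names and mimics that selection. It assumes run passes each name to string.match as described above. Anchoring a name with ^ and $ narrows the match to that exact name.

```lua
-- Hypothetical test names registered with a tester.
local names = {"test1", "test10", "test2", "mytest1"}

-- Return the names that a given pattern selects, mimicking
-- string.match-based selection (unanchored by default).
local function selected(pattern)
   local out = {}
   for _, name in ipairs(names) do
      if string.match(name, pattern) then
         out[#out + 1] = name
      end
   end
   return out
end

local loose = selected("test1")    -- unanchored: substring matches count
local exact = selected("^test1$")  -- anchored at both ends
print(table.concat(loose, ", "))
print(table.concat(exact, ", "))
```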

disable(testNames)

Prevent the given tests from running, where testNames can be a single string or a list of strings. More precisely, when run is invoked, it skips these tests while still printing an indication that they were skipped. This is useful for temporarily disabling tests without commenting out their code (for example, if they depend on upstream code that is currently broken), while explicitly flagging them as skipped.

Returns the torch.Tester instance.

local tester = torch.Tester()
local tests = torch.TestSuite()

function tests.brokenTest()
  -- ...
end

tester:add(tests):disable('brokenTest'):run()

Running 1 test
1/1 brokenTest .......................................................... [SKIP]
Completed 0 asserts in 1 test with 0 failures and 0 errors and 1 disabled

assert(condition [, message])

Check that condition is true (using the optional message if the test fails). Returns whether the test passed.

assertGeneralEq(got, expected [, tolerance] [, message])

General equality check between numbers, tables, strings, torch.Tensor objects, torch.Storage objects, etc.

Check that got and expected have the same contents, where tables are compared recursively, tensors and storages are compared elementwise, and numbers are compared within tolerance (default value 0). Other types are compared by strict equality. The optional message is used if the test fails. Returns whether the test passed.

eq(got, expected [, tolerance] [, message])

Convenience function; does the same as assertGeneralEq.
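As an illustration of the deep comparison, the following sketch compares two nested tables that mix numbers, strings, sub-tables and tensors. It assumes torch7 is installed. The structure and the tolerance are hypothetical. The tensors differ by 1e-7 in one element, which is within the 1e-6 tolerance, so the check should pass.

```lua
require 'torch'

local tester = torch.Tester()
local suite = torch.TestSuite()
local results = {}

function suite.testDeepEq()
   -- Hypothetical nested structure mixing numbers, strings and tensors.
   local got = {n = 1, name = "layer", w = torch.Tensor{0.5, 0.25}, sub = {1, 2}}
   local expected = {n = 1, name = "layer", w = torch.Tensor{0.5, 0.2500001}, sub = {1, 2}}
   -- Tables are compared recursively, tensors elementwise,
   -- and numbers within the given tolerance.
   results.withTol = tester:eq(got, expected, 1e-6, "structures should match")
end

tester:add(suite)
tester:run()
```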

assertGeneralNe(got, unexpected [, tolerance] [, message])

General inequality check between numbers, tables, strings, torch.Tensor objects, torch.Storage objects, etc.

Check that got and unexpected have different contents, where tables are compared recursively, tensors and storages are compared elementwise, and numbers are compared within tolerance (default value 0). Other types are compared by strict equality. The optional message is used if the test fails. Returns whether the test passed.

ne(got, unexpected [, tolerance] [, message])

Convenience function; does the same as assertGeneralNe.
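The tolerance cuts both ways for ne: values that fall within the tolerance count as equal, so ne fails on them. The following sketch assumes torch7 is installed, and the values are hypothetical. One assertion fails by design, so run raises after its report and is called under pcall.

```lua
require 'torch'

local tester = torch.Tester()
local suite = torch.TestSuite()
local results = {}

function suite.testNe()
   -- 1.0 and 1.05 are distinct under the default tolerance of 0...
   results.strict = tester:ne(1.0, 1.05, "values should differ")
   -- ...but count as equal with a tolerance of 0.1, so this ne fails.
   results.loose = tester:ne(1.0, 1.05, 0.1, "values should differ by more than 0.1")
end

tester:add(suite)
-- The second assertion fails, so run() raises after printing its report.
local ok = pcall(tester.run, tester)
```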

assertlt(a, b [, message])

Check that a < b (using the optional message if the test fails), where a and b are numbers. Returns whether the test passed.

assertgt(a, b [, message])

Check that a > b (using the optional message if the test fails), where a and b are numbers. Returns whether the test passed.

assertle(a, b [, message])

Check that a <= b (using the optional message if the test fails), where a and b are numbers. Returns whether the test passed.

assertge(a, b [, message])

Check that a >= b (using the optional message if the test fails), where a and b are numbers. Returns whether the test passed.
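The following sketch shows the numeric comparisons and assert, and captures their return values. It assumes torch7 is installed. The learning-rate value is hypothetical. The strict assertgt fails when both sides are equal, so run raises and is wrapped in pcall.

```lua
require 'torch'

local tester = torch.Tester()
local suite = torch.TestSuite()
local results = {}

function suite.testComparisons()
   local lr = 0.01  -- hypothetical learning rate
   results.assert = tester:assert(lr > 0, "learning rate must be positive")
   results.lt = tester:assertlt(lr, 1, "learning rate should be below 1")
   results.ge = tester:assertge(lr, 0.01, "learning rate should be at least 0.01")
   -- Strict comparison: fails because lr equals 0.01.
   results.gt = tester:assertgt(lr, 0.01, "learning rate should exceed 0.01")
end

tester:add(suite)
local ok = pcall(tester.run, tester)
```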

asserteq(a, b [, message])

Check that a == b (using the optional message if the test fails). Note that this uses Lua's built-in equality operator, so objects such as tensors that have the same content but are distinct objects will fail this test; consider using assertGeneralEq() instead. Returns whether the test passed.

assertne(a, b [, message])

Check that a ~= b (using the optional message if the test fails). Note that this uses Lua's built-in inequality operator, so objects such as tensors that have the same content but are distinct objects will pass this test; consider using assertGeneralNe() instead. Returns whether the test passed.
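The pitfall with asserteq and assertne comes from how Lua's == treats objects. The following plain-Lua snippet uses ordinary tables as stand-ins. Per the notes above, tensors behave the same way under these two assertions. Two distinct tables with identical contents compare unequal, because == checks identity rather than contents.

```lua
-- Two distinct tables with identical contents.
local a = {1, 2, 3}
local b = {1, 2, 3}

-- Plain Lua == compares table identity, not contents.
local sameObject = (a == a)    -- same object: true
local sameContents = (a == b)  -- different objects: false
print(sameObject, sameContents)
```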

assertalmosteq(a, b [, tolerance] [, message])

Check that |a - b| <= tolerance (using the optional message if the test fails), where a and b are numbers, and tolerance is an optional number (default 1e-16). Returns whether the test passed.
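The default tolerance of 1e-16 is smaller than the gap between adjacent doubles near 1.0, so for values of that magnitude it amounts to exact equality. The following plain-Lua sketch uses a hypothetical helper almostEq that mirrors only the documented formula, not Tester's actual code.

```lua
-- Hypothetical helper: |a - b| <= tolerance, with the documented default 1e-16.
local function almostEq(a, b, tolerance)
   tolerance = tolerance or 1e-16
   return math.abs(a - b) <= tolerance
end

-- 0.1 + 0.2 differs from 0.3 by 2^-54 (about 5.6e-17): within the default.
local r1 = almostEq(0.1 + 0.2, 0.3)
-- Near 1.0, adjacent doubles are at least 2^-53 (about 1.1e-16) apart,
-- so any nonzero difference there exceeds the default tolerance.
local r2 = almostEq(1.0, 1.0 - 2^-53)
-- An explicit, larger tolerance accepts the same pair.
local r3 = almostEq(1.0, 1.0 - 2^-53, 1e-12)
print(r1, r2, r3)
```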

assertTensorEq(ta, tb [, tolerance] [, message])

Check that max(abs(ta - tb)) <= tolerance (using the optional message if the test fails), where ta and tb are tensors, and tolerance is an optional number (default 1e-16). Tensors that are different types or sizes will cause this check to fail. Returns whether the test passed.

assertTensorNe(ta, tb [, tolerance] [, message])

Check that max(abs(ta - tb)) > tolerance (using the optional message if the test fails), where ta and tb are tensors, and tolerance is an optional number (default 1e-16). Tensors that are different types or sizes will cause this check to pass. Returns whether the test passed.
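The following sketch shows both tensor checks, including the size-mismatch rule. It assumes torch7 is installed. The tensors and tolerance are hypothetical. The second call omits the tolerance and passes only a message, which the optional-argument rule allows.

```lua
require 'torch'

local tester = torch.Tester()
local suite = torch.TestSuite()
local results = {}

function suite.testTensors()
   local t = torch.Tensor{1, 2, 3}
   -- Elements differ by 1e-9, well within the 1e-6 tolerance.
   results.eq = tester:assertTensorEq(t, t + 1e-9, 1e-6, "should be within 1e-6")
   -- Different sizes: assertTensorNe passes regardless of tolerance.
   local longer = torch.Tensor{1, 2, 3, 4}
   results.neSize = tester:assertTensorNe(t, longer, "sizes differ")
end

tester:add(suite)
tester:run()
```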

assertTableEq(ta, tb [, tolerance] [, message])

Check that the two tables have the same contents, comparing them recursively, where objects such as tensors are compared using their contents. Numbers (such as those appearing in tensors) are considered equal if their difference is at most the given tolerance.

assertTableNe(ta, tb [, tolerance] [, message])

Check that the two tables have distinct contents, comparing them recursively, where objects such as tensors are compared using their contents. Numbers (such as those appearing in tensors) are considered equal if their difference is at most the given tolerance.
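The following sketch shows how the tolerance decides both table checks. It assumes torch7 is installed. The tables and tolerances are hypothetical. The two tables differ by about 1e-4 in one tensor element. That difference is within 1e-3, so assertTableEq passes. It exceeds 1e-6, so assertTableNe also passes. Both tolerances are given explicitly because the text does not state their defaults.

```lua
require 'torch'

local tester = torch.Tester()
local suite = torch.TestSuite()
local results = {}

function suite.testTables()
   local a = {w = torch.Tensor{1, 2}, meta = {epochs = 10}}
   local b = {w = torch.Tensor{1, 2.0001}, meta = {epochs = 10}}
   -- Equal within 1e-3 (the 1e-4 difference is inside the tolerance).
   results.eqTol = tester:assertTableEq(a, b, 1e-3, "tables should match within 1e-3")
   -- Distinct at 1e-6 (the 1e-4 difference is outside the tolerance).
   results.neTight = tester:assertTableNe(a, b, 1e-6, "tables should differ at 1e-6")
end

tester:add(suite)
tester:run()
```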

assertError(f [, message])

Check that calling f() (via pcall) raises an error (using the optional message if the test fails). Returns whether the test passed.

assertNoError(f [, message])

Check that calling f() (via pcall) does not raise an error (using the optional message if the test fails). Returns whether the test passed.

assertErrorMsg(f, errmsg [, message])

Check that calling f() (via pcall) raises an error with the specific error message errmsg (using the optional message if the test fails). Returns whether the test passed.

assertErrorPattern(f, errPattern [, message])

Check that calling f() (via pcall) raises an error matching errPattern (using the optional message if the test fails). The matching is done using string.find; in particular, substrings will match. Returns whether the test passed.

assertErrorObj(f, errcomp [, message])

Check that calling f() (via pcall) raises an error object err such that calling errcomp(err) returns true (using the optional message if the test fails). Returns whether the test passed.
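The following sketch exercises each error assertion. It assumes torch7 is installed. The helper functions and the error table are hypothetical. One Lua detail matters for assertErrorMsg: error("boom") prefixes the message with position information such as "file:line:", so an exact match needs error("boom", 0). assertErrorPattern, by contrast, finds "boom" as a substring of the prefixed message.

```lua
require 'torch'

local tester = torch.Tester()
local suite = torch.TestSuite()
local results = {}

function suite.testErrors()
   -- Level 0 omits the "file:line:" prefix, so the message is exactly "boom".
   local function exact() error("boom", 0) end
   -- The default level adds position info, e.g. "test.lua:12: boom".
   local function positioned() error("boom") end
   -- Errors can also be non-string values such as tables.
   local function withObject() error({code = 42}) end

   results.any = tester:assertError(positioned, "should raise")
   results.none = tester:assertNoError(function() return 1 end, "should not raise")
   results.msg = tester:assertErrorMsg(exact, "boom", "exact message")
   results.pattern = tester:assertErrorPattern(positioned, "boom", "substring match")
   results.obj = tester:assertErrorObj(withObject,
      function(err) return type(err) == "table" and err.code == 42 end,
      "error object should carry code 42")
end

tester:add(suite)
tester:run()
```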

setEarlyAbort(earlyAbort)

If earlyAbort is true, testing stops at the first test failure. By default, this is off.

setRethrowErrors(rethrowErrors)

If rethrowErrors is true, Lua errors encountered while running the tests are rethrown instead of being caught by the tester. By default, this is off.

setSummaryOnly(summaryOnly)

If summaryOnly is true, only the pass/fail status of each test is printed, rather than full error messages. By default, this is off.
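A terse configuration for a continuous-integration run might combine these settings. The following sketch assumes torch7 is installed, and the test name is hypothetical. The setters' return values are not documented here, so they are called as separate statements rather than chained. The single test fails by design, so run is expected to raise and is wrapped in pcall.

```lua
require 'torch'

local tester = torch.Tester()
tester:setEarlyAbort(true)   -- stop at the first failing test
tester:setSummaryOnly(true)  -- print pass/fail status only, not full messages

local suite = torch.TestSuite()
function suite.testBroken()
   tester:eq(1, 2, "deliberately failing")
end
tester:add(suite)

-- A failure still makes run() raise an error after reporting.
local ok = pcall(tester.run, tester)
```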

TestSuite

A TestSuite is used in conjunction with Tester. It is created via torch.TestSuite() and behaves like a plain Lua table, except that it also checks that duplicate tests are not created. It is recommended that you always use a TestSuite instead of a plain table for your tests.

The following example code attempts to add a function with the same name twice to a TestSuite (a surprisingly common mistake), which gives an error.

> test = torch.TestSuite()
>
> function test.myTest()
>    -- ...
> end
>
> -- ...
>
> function test.myTest()
>    -- ...
> end
torch/TestSuite.lua:16: Test myTest is already defined.
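Because the duplicate check raises an ordinary Lua error at definition time, it can be caught and inspected with pcall, for instance in a script that builds suites programmatically. The following sketch assumes torch7 is installed, and the test name is hypothetical.

```lua
require 'torch'

local suite = torch.TestSuite()
function suite.myTest() end

-- Redefining an existing test raises an error; catch it with pcall.
local ok, err = pcall(function()
   function suite.myTest() end
end)
print(ok, err)
```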