From d8ff64c5707f716199e7782d630e44e2cf402a54 Mon Sep 17 00:00:00 2001 From: David Saxton Date: Wed, 24 Feb 2016 17:54:10 +0000 Subject: [PATCH] Replace torch.Tester with totem.Tester + extra stuff. This should bring a lot of benefit to code that uses torch.Tester (totem will eventually become deprecated). Note that torch.Tester and totem.Tester once shared the same code - this change brings it full circle. At a glance, extra functionality includes: - A general equality checker that accepts many different objects. - Deep table comparison with precision checking. - Stricter argument checking in using the test functions. - Better output. - torch.Storage comparison. - Extra features for fine-grained control of testing. --- .travis.yml | 1 + CMakeLists.txt | 2 +- TestSuite.lua | 30 ++ Tester.lua | 947 +++++++++++++++++++++++++++++++------ doc/tester.md | 344 +++++++++++--- init.lua | 1 + rocks/torch-scm-1.rockspec | 3 +- test/test.lua | 105 ++-- test/test_Tester.lua | 626 ++++++++++++++++++++++++ test/test_qr.lua | 2 +- test/test_sharedmem.lua | 2 +- test/test_writeObject.lua | 4 +- 12 files changed, 1775 insertions(+), 292 deletions(-) create mode 100644 TestSuite.lua create mode 100644 test/test_Tester.lua diff --git a/.travis.yml b/.travis.yml index fce9af43..46e731e0 100644 --- a/.travis.yml +++ b/.travis.yml @@ -57,3 +57,4 @@ script: - ${TESTLUA} -ltorch -e "t=torch.test(); if t.errors[1] then os.exit(1) end" - cd test - ${TESTLUA} test_writeObject.lua +- ${TESTLUA} test_Tester.lua diff --git a/CMakeLists.txt b/CMakeLists.txt index 1c555d8b..611258b3 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -71,7 +71,7 @@ INCLUDE_DIRECTORIES(BEFORE "${CMAKE_CURRENT_SOURCE_DIR}/lib/luaT") LINK_DIRECTORIES("${LUA_LIBDIR}") SET(src DiskFile.c File.c MemoryFile.c PipeFile.c Storage.c Tensor.c Timer.c utils.c init.c TensorOperator.c TensorMath.c random.c Generator.c) -SET(luasrc init.lua File.lua Tensor.lua CmdLine.lua FFI.lua Tester.lua test/test.lua) +SET(luasrc init.lua File.lua Tensor.lua CmdLine.lua FFI.lua Tester.lua TestSuite.lua test/test.lua) # Necessary do generate wrapper ADD_TORCH_WRAP(tensormathwrap TensorMath.lua) diff --git a/TestSuite.lua b/TestSuite.lua new file mode 100644 index 00000000..630c2c94 --- /dev/null +++ b/TestSuite.lua @@ -0,0 +1,30 @@ +function torch.TestSuite() + local obj = { + __tests = {}, + __isTestSuite = true + } + + local metatable = {} + + function metatable:__index(key) + return self.__tests[key] + end + + function metatable:__newindex(key, value) + if self.__tests[key] ~= nil then + error("Test " .. tostring(key) .. " is already defined.") + end + if type(value) ~= "function" then + if type(value) == "table" then + error("Nested tables of tests are not supported") + else + error("Only functions are supported as members of a TestSuite") + end + end + self.__tests[key] = value + end + + setmetatable(obj, metatable) + + return obj +end diff --git a/Tester.lua b/Tester.lua index dbd1414b..3eedc6b0 100644 --- a/Tester.lua +++ b/Tester.lua @@ -1,231 +1,878 @@ + +require 'sys' -- for sys.COLORS + +-- Lua 5.2 compatibility +local unpack = unpack or table.unpack + +local check = {} -- helper functions, defined at the bottom of the file + local Tester = torch.class('torch.Tester') function Tester:__init() self.errors = {} self.tests = {} - self.testnames = {} - self.curtestname = '' + self.warnings = {} + self._warningCount = {} + self.disabledTests = {} + self._currentTestName = '' + + -- To maintain backwards compatibility (at least for a short while), + -- disable exact dimension checking of tensors when :assertTensorEq is + -- called. Thus {{1}} == {1} when this flag is true. + -- + -- Note that other methods that suppose tensor checking (such as + -- :assertGeneralEq) ignore this flag, since previously they didn't + -- exist or support tensor equality checks at all, so there is no + -- old code that uses these functions and relies on the behaviour. + -- + -- Note also that if the dimension check fails with this flag is true, then + -- will show a warning. + self._assertTensorEqIgnoresDims = true end +function Tester:setEarlyAbort(earlyAbort) + self.earlyAbort = earlyAbort +end + +function Tester:setRethrowErrors(rethrow) + self.rethrow = rethrow +end + +function Tester:setSummaryOnly(summaryOnly) + self.summaryOnly = summaryOnly +end + +-- Add a success to the test. +function Tester:_success() + local name = self._currentTestName + self.assertionPass[name] = self.assertionPass[name] + 1 + return true +end + +function Tester:_addDebugInfo(message) + local ss = debug.traceback('tester', 3) or '' + ss = ss:match('.-\n([^\n]+\n[^\n]+)\n[^\n]+xpcall') or '' + local name = self._currentTestName + return (name ~= '' and name .. '\n' or '') .. message .. '\n' .. ss +end -function Tester:assert_sub (condition, message) - self.countasserts = self.countasserts + 1 - if not condition then - local ss = debug.traceback('tester',2) - --print(ss) - ss = ss:match('[^\n]+\n[^\n]+\n([^\n]+\n[^\n]+)\n') - self.errors[#self.errors+1] = self.curtestname .. '\n' .. (message or '') .. '\n' .. ss .. '\n' +-- Add a failure to the test. +function Tester:_failure(message) + local name = self._currentTestName + self.assertionFail[name] = self.assertionFail[name] + 1 + self.errors[#self.errors + 1] = self:_addDebugInfo(message) + return false +end + +-- Add a warning to the test +function Tester:_warning(message) + local name = self._currentTestName + self._warningCount[name] = (self._warningCount[name] or 0) + 1 + self.warnings[#self.warnings + 1] = self:_addDebugInfo(message) +end + +-- Call this during a test run with `condition = true` to log a success, or with +-- `condition = false` to log a failure (using `message`). +function Tester:_assert_sub(condition, message) + if condition then + return self:_success() + else + return self:_failure(message) end end -function Tester:assert (condition, message) - self:assert_sub(condition,string.format('%s\n%s condition=%s',(message or ''),' BOOL violation ', tostring(condition))) +local function getMessage(message, ...) + assert(next{...} == nil, "Unexpected arguments passed to test function") + if message then + assert(type(message) == 'string', 'message parameter must be a string') + if message ~= '' then + return message .. '\n' + end + end + return '' end -function Tester:assertlt (val, condition, message) - self:assert_sub(val= 0, "tolerance cannot be negative") + else + error("Unrecognized argument; should be a tolerance or message", a) + end + end + message = message or '' + tolerance = tolerance or defaultTolerance + return tolerance, message end -function Tester:assertgt (val, condition, message) - self:assert_sub(val>condition,string.format('%s\n%s val=%s, condition=%s',(message or ''),' GT(>) violation ', tostring(val), tostring(condition))) +function Tester:assert(condition, ...) + local message = getMessage(...) + if type(condition) ~= 'boolean' then + self:_warning(" :assert should only be used for boolean conditions. " + .. "To check for non-nil variables, do this explicitly: " + .. "Tester:assert(var ~= nil).") + end + return self:_assert_sub(condition, + string.format('%sBOOL violation condition=%s', + message, tostring(condition))) end -function Tester:assertle (val, condition, message) - self:assert_sub(val<=condition,string.format('%s\n%s val=%s, condition=%s',(message or ''),' LE(<=) violation ', tostring(val), tostring(condition))) +function Tester:assertGeneralEq(got, expected, ...) + return self:_eqOrNeq(got, expected, false, ...) end -function Tester:assertge (val, condition, message) - self:assert_sub(val>=condition,string.format('%s\n%s val=%s, condition=%s',(message or ''),' GE(>=) violation ', tostring(val), tostring(condition))) +function Tester:eq(got, expected, ...) + return self:assertGeneralEq(got, expected, ...) end -function Tester:asserteq (val, condition, message) - self:assert_sub(val==condition,string.format('%s\n%s val=%s, condition=%s',(message or ''),' EQ(==) violation ', tostring(val), tostring(condition))) +function Tester:assertGeneralNe(got, unexpected, ...) + return self:_eqOrNeq(got, unexpected, true, ...) end -function Tester:assertalmosteq (a, b, condition, message) - condition = condition or 1e-16 - local err = math.abs(a-b) - self:assert_sub(err < condition, string.format('%s\n%s val=%s, condition=%s',(message or ''),' ALMOST_EQ(==) violation ', tostring(err), tostring(condition))) +function Tester:ne(got, unexpected, ...) + return self:assertGeneralNe(got, unexpected, ...) end -function Tester:assertne (val, condition, message) - self:assert_sub(val~=condition,string.format('%s\n%s val=%s, condition=%s',(message or ''),' NE(~=) violation ', tostring(val), tostring(condition))) +function Tester:_eqOrNeq(got, expected, negate, ...) + local tolerance, message = getToleranceAndMessage(0, ...) + local success, subMessage = check.areEq(got, expected, tolerance, negate) + subMessage = subMessage or '' + return self:_assert_sub(success, message .. subMessage) end -function Tester:assertTensorEq(ta, tb, condition, message) - if ta:dim() == 0 and tb:dim() == 0 then - return +function Tester:assertlt(a, b, ...) + local message = getMessage(...) + return self:_assert_sub(a < b, + string.format('%sLT failed: %s >= %s', + message, tostring(a), tostring(b))) +end + +function Tester:assertgt(a, b, ...) + local message = getMessage(...) + return self:_assert_sub(a > b, + string.format('%sGT failed: %s <= %s', + message, tostring(a), tostring(b))) +end + +function Tester:assertle(a, b, ...) + local message = getMessage(...) + return self:_assert_sub(a <= b, + string.format('%sLE failed: %s > %s', + message, tostring(a), tostring(b))) +end + +function Tester:assertge(a, b, ...) + local message = getMessage(...) + return self:_assert_sub(a >= b, + string.format('%sGE failed: %s < %s', + message, tostring(a), tostring(b))) +end + +function Tester:assertalmosteq(a, b, ...) + local tolerance, message = getToleranceAndMessage(1e-16, ...) + local diff = math.abs(a - b) + return self:_assert_sub( + diff <= tolerance, + string.format( + '%sALMOST_EQ failed: %s ~= %s with tolerance=%s', + message, tostring(a), tostring(b), tostring(tolerance))) +end + +function Tester:asserteq(a, b, ...) + local message = getMessage(...) + return self:_assert_sub(a == b, + string.format('%sEQ failed: %s ~= %s', + message, tostring(a), tostring(b))) +end + +function Tester:assertne(a, b, ...) + local message = getMessage(...) + if type(a) == type(b) and type(a) == 'table' or type(a) == 'userdata' then + self:_warning(" :assertne should only be used to compare basic lua " + .. "objects (numbers, booleans, etc). Consider using " + .. "either :assertGeneralNe or :assert(a ~= b).") end - local diff = ta-tb - local err = diff:abs():max() - self:assert_sub(err<=condition,string.format('%s\n%s val=%s, condition=%s',(message or ''),' TensorEQ(==) violation ', tostring(err), tostring(condition))) + return self:_assert_sub(a ~= b, + string.format('%sNE failed: %s == %s', + message, tostring(a), tostring(b))) end -function Tester:assertTensorNe(ta, tb, condition, message) - local diff = ta-tb - local err = diff:abs():max() - self:assert_sub(err>condition,string.format('%s\n%s val=%s, condition=%s',(message or ''),' TensorNE(~=) violation ', tostring(err), tostring(condition))) +function Tester:assertTensorEq(ta, tb, ...) + return self:_assertTensorEqOrNeq(ta, tb, false, ...) end -local function areTablesEqual(ta, tb) - local function isIncludedIn(ta, tb) - if type(ta) ~= 'table' or type(tb) ~= 'table' then - return ta == tb - end - for k, v in pairs(tb) do - if not areTablesEqual(ta[k], v) then return false end - end - return true +function Tester:assertTensorNe(ta, tb, ...) + return self:_assertTensorEqOrNeq(ta, tb, true, ...) +end + +function Tester:_assertTensorEqOrNeq(ta, tb, negate, ...) + assert(torch.isTensor(ta), "First argument should be a Tensor") + assert(torch.isTensor(tb), "Second argument should be a Tensor") + + local tolerance, message = getToleranceAndMessage(0, ...) + local success, subMessage = + check.areTensorsEq(ta, tb, tolerance, negate, + self._assertTensorEqIgnoresDims) + subMessage = subMessage or '' + + if self._assertTensorEqIgnoresDims and (not negate) and success + and not ta:isSameSizeAs(tb) then + self:_warning("Tensors have the same content but different dimensions. " + .. "For backwards compatability, they are considered equal, " + .. "but this may change in the future. Consider using :eq " + .. "to check for equality instead.") end - return isIncludedIn(ta, tb) and isIncludedIn(tb, ta) + return self:_assert_sub(success, message .. subMessage) +end + +function Tester:assertTableEq(ta, tb, ...) + return self:_assertTableEqOrNeq(ta, tb, false, ...) end -function Tester:assertTableEq(ta, tb, message) - self:assert_sub(areTablesEqual(ta, tb), string.format('%s\n%s',(message or ''),' TableEQ(==) violation ')) +function Tester:assertTableNe(ta, tb, ...) + return self:_assertTableEqOrNeq(ta, tb, true, ...) end -function Tester:assertTableNe(ta, tb, message) - self:assert_sub(not areTablesEqual(ta, tb), string.format('%s\n%s',(message or ''),' TableEQ(==) violation ')) +function Tester:_assertTableEqOrNeq(ta, tb, negate, ...) + assert(type(ta) == 'table', "First argument should be a Table") + assert(type(tb) == 'table', "Second argument should be a Table") + return self:_eqOrNeq(ta, tb, negate, ...) end -function Tester:assertError(f, message) - return self:assertErrorObj(f, function(err) return true end, message) +function Tester:assertError(f, ...) + return self:assertErrorObj(f, function() return true end, ...) end -function Tester:assertErrorMsg(f, errmsg, message) - return self:assertErrorObj(f, function(err) return err == errmsg end, message) +function Tester:assertNoError(f, ...) + local message = getMessage(...) + local status, err = pcall(f) + return self:_assert_sub(status, + string.format('%sERROR violation: err=%s', message, + tostring(err))) end -function Tester:assertErrorPattern(f, errPattern, message) - return self:assertErrorObj(f, function(err) return string.find(err, errPattern) ~= nil end, message) +function Tester:assertErrorMsg(f, errmsg, ...) + return self:assertErrorObj(f, function(err) return err == errmsg end, ...) end -function Tester:assertErrorObj(f, errcomp, message) - -- errcomp must be a function that compares the error object to its expected value +function Tester:assertErrorPattern(f, errPattern, ...) + local function errcomp(err) + return string.find(err, errPattern) ~= nil + end + return self:assertErrorObj(f, errcomp, ...) +end + +function Tester:assertErrorObj(f, errcomp, ...) + local message = getMessage(...) local status, err = pcall(f) - self:assert_sub(status == false and errcomp(err), string.format('%s\n%s err=%s', (message or ''),' ERROR violation ', tostring(err))) + return self:_assert_sub((not status) and errcomp(err), + string.format('%sERROR violation: err=%s', message, + tostring(err))) +end + +function Tester:add(f, name) + if type(f) == "table" then + assert(name == nil, "Name parameter is forbidden for a table of tests, " + .. "since its use is ambiguous") + if f.__isTestSuite then + f = f.__tests + else + self:_warning("Should use TestSuite rather than plain lua table") + end + for i, v in pairs(f) do + -- We forbid nested tests because the "expected" behaviour when a named + -- test is run in the case that the named test is in fact a table of + -- tests is not supported. Similar issue with _setUp and _tearDown + -- functions inside nested tests. + assert(type(v) ~= 'table', "Nested sets of tests are not supported") + self:add(v, i) + end + return self + end + + assert(type(f) == 'function', + "Only tables of functions and functions supported") + + if name == '_setUp' then + assert(not self._setUp, "Only one set-up function allowed") + self._setUp = f + elseif name == '_tearDown' then + assert(not self._tearDown, "Only one tear-down function allowed") + self._tearDown = f + else + name = name or 'unknown' + if self.tests[name] ~= nil then + error('Test with name ' .. name .. ' already exists!') + end + self.tests[name] = f + end + return self +end + +function Tester:disable(testNames) + if type(testNames) == 'string' then + testNames = {testNames} + end + assert(type(testNames) == 'table', "Expecting name or list for disable") + for _, name in ipairs(testNames) do + assert(self.tests[name], "Unrecognized test '" .. name .. "'") + self.disabledTests[name] = true + end + return self +end + +function Tester:run(testNames) + local tests = self:_getTests(testNames) + self.assertionPass = {} + self.assertionFail = {} + self.haveWarning = {} + self.testError = {} + for name in pairs(tests) do + self.assertionPass[name] = 0 + self.assertionFail[name] = 0 + self.testError[name] = 0 + self._warningCount[name] = 0 + end + self:_run(tests) + self:_report(tests) + + -- Throws an error on test failure/error, so that test script returns + -- with nonzero return value. + for name in pairs(tests) do + assert(self.assertionFail[name] == 0, + 'An error was found while running tests!') + assert(self.testError[name] == 0, + 'An error was found while running tests!') + end + + return 0 +end + +local function pluralize(num, str) + local stem = num .. ' ' .. str + if num == 1 then + return stem + else + return stem .. 's' + end +end + +local NCOLS = 80 +local coloured +local c = {} +if arg then -- have we been invoked from the commandline? + c = sys.COLORS + coloured = function(str, colour) + return colour .. str .. c.none + end +else + coloured = function(str) + return str + end end +function Tester:_run(tests) + local ntests = 0 + for _ in pairs(tests) do + ntests = ntests + 1 + end + + local ntestsAsString = string.format('%u', ntests) + local cfmt = string.format('%%%uu/%u ', ntestsAsString:len(), ntestsAsString) + local cfmtlen = ntestsAsString:len() * 2 + 2 + + local function bracket(str) + return '[' .. str .. ']' + end + + io.write('Running ' .. pluralize(ntests, 'test') .. '\n') + local i = 1 + for name, fn in pairs(tests) do + self._currentTestName = name + + -- TODO: compute max length of name and cut it down to size if needed + local strinit = coloured(string.format(cfmt, i), c.cyan) + .. self._currentTestName .. ' ' + .. string.rep('.', + NCOLS - 6 - 2 - + cfmtlen - self._currentTestName:len()) + .. ' ' + io.write(strinit .. bracket(coloured('WAIT', c.cyan))) + io.flush() + + local status, message, pass, skip + if self.disabledTests[name] then + skip = true + else + skip = false + if self._setUp then + self._setUp(name) + end + if self.rethrow then + status = true + local nerr = #self.errors + message = fn() + pass = nerr == #self.errors + else + status, message, pass = self:_pcall(fn) + end + if self._tearDown then + self._tearDown(name) + end + end + + io.write('\r') + io.write(strinit) + + if skip then + io.write(bracket(coloured('SKIP', c.yellow))) + elseif not status then + self.testError[name] = 1 + io.write(bracket(coloured('ERROR', c.magenta))) + elseif not pass then + io.write(bracket(coloured('FAIL', c.red))) + else + io.write(bracket(coloured('PASS', c.green))) + if self._warningCount[name] > 0 then + io.write('\n' .. string.rep(' ', NCOLS - 10)) + io.write(bracket(coloured('+warning', c.yellow))) + end + end + io.write('\n') + io.flush() + + if self.earlyAbort and (i < ntests) and (not status or not pass) then + io.write('Aborting on first error, not all tests have been executed\n') + break + end + i = i + 1 -function Tester:pcall(f) + collectgarbage() + end +end + +function Tester:_pcall(f) local nerr = #self.errors - -- local res = f() local stat, result = xpcall(f, debug.traceback) if not stat then - if result:find("interrupted!") then - self:report() - error("interrupted!") - end - self.errors[#self.errors+1] = self.curtestname .. '\n Function call failed \n' .. result .. '\n' + self.errors[#self.errors + 1] = + self._currentTestName .. '\n Function call failed\n' .. result .. '\n' end return stat, result, stat and (nerr == #self.errors) - -- return true, res, nerr == #self.errors end -function Tester:report(tests) - if not tests then - tests = self.tests +function Tester:_getTests(testNames) + if testNames == nil then + return self.tests end - print('Completed ' .. self.countasserts .. ' asserts in ' .. #tests .. ' tests with ' .. #self.errors .. ' errors') - print() - print(string.rep('-',80)) - for i,v in ipairs(self.errors) do - print(v) - print(string.rep('-',80)) + if type(testNames) == 'string' then + testNames = {testNames} + end + assert(type(testNames) == 'table', + "Only accept a name or table of test names (or nil for all tests)") + + local function getMatchingNames(pattern) + local matchingNames = {} + for name in pairs(self.tests) do + if string.match(name, pattern) then + table.insert(matchingNames, name) + end + end + return matchingNames + end + + local tests = {} + for _, pattern in ipairs(testNames) do + local matchingNames = getMatchingNames(pattern) + assert(#matchingNames > 0, "Couldn't find test '" .. pattern .. "'") + for _, name in ipairs(matchingNames) do + tests[name] = self.tests[name] + end end + return tests end -function Tester:run(run_tests) - local tests, testnames +function Tester:_report(tests) + local ntests = 0 + local nfailures = 0 + local nerrors = 0 + local nskipped = 0 + local nwarnings = 0 self.countasserts = 0 - tests = self.tests - testnames = self.testnames - if type(run_tests) == 'string' then - run_tests = {run_tests} - end - if type(run_tests) == 'table' then - tests = {} - testnames = {} - for i,fun in ipairs(self.tests) do - for j,name in ipairs(run_tests) do - if self.testnames[i] == name or i == name then - tests[#tests+1] = self.tests[i] - testnames[#testnames+1] = self.testnames[i] - end - end + for name in pairs(tests) do + ntests = ntests + 1 + self.countasserts = self.countasserts + self.assertionFail[name] + + self.assertionPass[name] + if self.assertionFail[name] > 0 then + nfailures = nfailures + 1 + end + if self.testError[name] > 0 then + nerrors = nerrors + 1 end + if self._warningCount[name] > 0 then + nwarnings = nwarnings + 1 + end + if self.disabledTests[name] then + nskipped = nskipped + 1 + end + end + if self._warningCount[''] then + nwarnings = nwarnings + self._warningCount[''] end - self:_run(tests, testnames) - self:report(tests) + io.write('Completed ' .. pluralize(self.countasserts, 'assert')) + io.write(' in ' .. pluralize(ntests, 'test') .. ' with ') + io.write(coloured(pluralize(nfailures, 'failure'), + nfailures == 0 and c.green or c.red)) + io.write(' and ') + io.write(coloured(pluralize(nerrors, 'error'), + nerrors == 0 and c.green or c.magenta)) + if nwarnings > 0 then + io.write(' and ') + io.write(coloured(pluralize(nwarnings, 'warning'), c.yellow)) + end + if nskipped > 0 then + io.write(' and ') + io.write(coloured(nskipped .. ' disabled', c.yellow)) + end + io.write('\n') + + -- Prints off a message separated by ----- + local haveSection = false + local function addSection(text) + local function printDashes() + io.write(string.rep('-', NCOLS) .. '\n') + end + if not haveSection then + printDashes() + haveSection = true + end + io.write(text .. '\n') + printDashes() + end + + if not self.summaryOnly then + for _, v in ipairs(self.errors) do + addSection(v) + end + for _, v in ipairs(self.warnings) do + addSection(v) + end + end end ---[[ Run exactly the given test functions with the given names. -This doesn't do any matching or filtering, or produce a final report. It -is internal to Tester:run(). +--[[ Tests for tensor equality between two tensors of matching sizes and types. + +Tests whether the maximum element-wise difference between `ta` and `tb` is less +than or equal to `tolerance`. + +Arguments: +* `ta` (tensor) +* `tb` (tensor) +* `tolerance` (number) maximum elementwise difference between `ta` and `tb`. +* `negate` (boolean) if true, we invert success and failure. +* `storage` (boolean) if true, we print an error message referring to Storages + rather than Tensors. + +Returns: +1. success, boolean that indicates success +2. failure_message, string or nil ]] -function Tester:_run(tests, testnames) - print('Running ' .. #tests .. ' tests') - local statstr = string.rep('_',#tests) - local pstr = '' - io.write(statstr .. '\r') - for i,v in ipairs(tests) do - self.curtestname = testnames[i] - - --clear - io.write('\r' .. string.rep(' ', pstr:len())) - io.flush() - --write - pstr = statstr:sub(1,i-1) .. '|' .. statstr:sub(i+1) .. ' ==> ' .. self.curtestname - io.write('\r' .. pstr) - io.flush() +function check.areSameFormatTensorsEq(ta, tb, tolerance, negate, storage) + local function ensureHasAbs(t) + -- Byte, Char and Short Tensors don't have abs + return t.abs and t or t:double() + end - local stat, message, pass = self:pcall(v) + ta = ensureHasAbs(ta) + tb = ensureHasAbs(tb) - if pass then - --io.write(string.format('\b_')) - statstr = statstr:sub(1,i-1) .. '_' .. statstr:sub(i+1) - else - statstr = statstr:sub(1,i-1) .. '*' .. statstr:sub(i+1) - --io.write(string.format('\b*')) - end + local diff = ta:clone():add(-1, tb):abs() + local err = diff:max() + local success = err <= tolerance + if negate then + success = not success + end - if not stat then - -- print() - -- print('Function call failed: Test No ' .. i .. ' ' .. testnames[i]) - -- print(message) - end - collectgarbage() + local errMessage + if not success then + local prefix = storage and 'Storage' or 'Tensor' + local violation = negate and 'NE(==)' or 'EQ(==)' + errMessage = string.format('%s%s violation: max diff=%s, tolerance=%s', + prefix, + violation, + tostring(err), + tostring(tolerance)) end - --clear - io.write('\r' .. string.rep(' ', pstr:len())) - io.flush() - -- write finish - pstr = statstr .. ' ==> Done ' - io.write('\r' .. pstr) - io.flush() - print() - print() + + return success, errMessage end -function Tester:add(f,name) - name = name or 'unknown' - if type(f) == "table" then - local orderedNames = {} - for n,_ in pairs(f) do - table.insert(orderedNames, n) +--[[ Tests for tensor equality. + +Tests whether the maximum element-wise difference between `ta` and `tb` is less +than or equal to `tolerance`. + +Arguments: +* `ta` (tensor) +* `tb` (tensor) +* `tolerance` (number) maximum elementwise difference between `ta` and `tb`. +* `negate` (boolean) if negate is true, we invert success and failure. +* `ignoreTensorDims` (boolean, default false) if true, then tensors of the same + size but different dimensions can still be considered equal, e.g., + {{1}} == {1}. For backwards compatibility. + +Returns: +1. success, boolean that indicates success +2. failure_message, string or nil +]] +function check.areTensorsEq(ta, tb, tolerance, negate, ignoreTensorDims) + ignoreTensorDims = ignoreTensorDims or false + + if not ignoreTensorDims and ta:dim() ~= tb:dim() then + return negate, 'The tensors have different dimensions' + end + + if ta:type() ~= tb:type() then + return negate, 'The tensors have different types' + end + + -- If we are comparing two empty tensors, return true. + -- This is needed because some functions below cannot be applied to tensors + -- of dimension 0. + if ta:dim() == 0 and tb:dim() == 0 then + return not negate, 'Both tensors are empty' + end + + local sameSize + if ignoreTensorDims then + sameSize = ta:nElement() == tb:nElement() + else + sameSize = ta:isSameSizeAs(tb) + end + if not sameSize then + return negate, 'The tensors have different sizes' + end + + return check.areSameFormatTensorsEq(ta, tb, tolerance, negate, false) +end + +local typesMatching = { + ['torch.ByteStorage'] = torch.ByteTensor, + ['torch.CharStorage'] = torch.CharTensor, + ['torch.ShortStorage'] = torch.ShortTensor, + ['torch.IntStorage'] = torch.IntTensor, + ['torch.LongStorage'] = torch.LongTensor, + ['torch.FloatStorage'] = torch.FloatTensor, + ['torch.DoubleStorage'] = torch.DoubleTensor, +} + +--[[ Tests for storage equality. + +Tests whether the maximum element-wise difference between `sa` and `sb` is less +than or equal to `tolerance`. + +Arguments: +* `sa` (storage) +* `sb` (storage) +* `tolerance` (number) maximum elementwise difference between `a` and `b`. +* `negate` (boolean) if negate is true, we invert success and failure. + +Returns: +1. success, boolean that indicates success +2. failure_message, string or nil +]] +function check.areStoragesEq(sa, sb, tolerance, negate) + if sa:size() ~= sb:size() then + return negate, 'The storages have different sizes' + end + + local typeOfsa = torch.type(sa) + local typeOfsb = torch.type(sb) + + if typeOfsa ~= typeOfsb then + return negate, 'The storages have different types' + end + + local ta = typesMatching[typeOfsa](sa) + local tb = typesMatching[typeOfsb](sb) + + return check.areSameFormatTensorsEq(ta, tb, tolerance, negate, true) +end + +--[[ Tests for general (deep) equality. + +The types of `got` and `expected` must match. +Tables are compared recursively. Keys and types of the associated values must +match, recursively. Numbers are compared with the given tolerance. +Torch tensors and storages are compared with the given tolerance on their +elementwise difference. Other types are compared for strict equality with the +regular Lua == operator. + +Arguments: +* `got` +* `expected` +* `tolerance` (number) maximum elementwise difference between `a` and `b`. +* `negate` (boolean) if negate is true, we invert success and failure. + +Returns: +1. success, boolean that indicates success +2. failure_message, string or nil +]] +function check.areEq(got, expected, tolerance, negate) + local errMessage + if type(got) ~= type(expected) then + if not negate then + errMessage = 'EQ failed: values have different types (first: ' + .. type(got) .. ', second: ' .. type(expected) .. ')' end - table.sort(orderedNames) - for _,n in pairs(orderedNames) do - self:add(f[n], n) + return negate, errMessage + elseif type(got) == 'number' then + local diff = math.abs(got - expected) + local ok = (diff <= tolerance) + if negate then + ok = not ok end - elseif type(f) == "function" then - self.tests[#self.tests+1] = f - self.testnames[#self.tests] = name + if not ok then + if negate then + errMessage = string.format("NE failed: %s == %s", + tostring(got), tostring(expected)) + else + errMessage = string.format("EQ failed: %s ~= %s", + tostring(got), tostring(expected)) + end + if tolerance > 0 then + errMessage = errMessage .. " with tolerance=" .. tostring(tolerance) + end + end + return ok, errMessage + elseif type(expected) == "table" then + return check.areTablesEq(got, expected, tolerance, negate) + elseif torch.isTensor(got) then + return check.areTensorsEq(got, expected, tolerance, negate) + elseif torch.isStorage(got) then + return check.areStoragesEq(got, expected, tolerance, negate) else - error('Tester:add(f) expects a function or a table of functions') + -- Below: we have the same type which is either userdata or a lua type + -- which is not a number. + local ok = (got == expected) + if negate then + ok = not ok + end + if not ok then + if negate then + errMessage = string.format("NE failed: %s (%s) == %s (%s)", + tostring(got), type(got), + tostring(expected), type(expected)) + else + errMessage = string.format("EQ failed: %s (%s) ~= %s (%s)", + tostring(got), type(got), + tostring(expected), type(expected)) + end + end + return ok, errMessage + end +end + +--[[ Tests for (deep) table equality. + +Tables are compared recursively. Keys and types of the associated values must +match, recursively. Numbers are compared with the given tolerance. +Torch tensors and storages are compared with the given tolerance on their +elementwise difference. Other types are compared for strict equality with the +regular Lua == operator. + +Arguments: +* `t1` (table) +* `t2` (table) +* `tolerance` (number) maximum elementwise difference between `a` and `b`. +* `negate` (boolean) if negate is true, we invert success and failure. + +Returns: +1. success, boolean that indicates success +2. failure_message, string or nil +]] +function check.areTablesEq(t1, t2, tolerance, negate) + -- Implementation detail: Instead of doing a depth-first table comparison + -- check (for example, using recursion), let's do a breadth-first search + -- using a queue. Why? Because if we have two tables that are quite deep + -- (e.g., a gModule from nngraph), then if they are different then it's + -- more useful to the user to show how they differ at as-shallow-a-depth + -- as possible. + local queue = {} + queue._head = 1 + queue._tail = 1 + function queue.isEmpty() + return queue._tail == queue._head + end + function queue.pop() + queue._head = queue._head + 1 + return queue[queue._head - 1] + end + function queue.push(value) + queue[queue._tail] = value + queue._tail = queue._tail + 1 + end + + queue.push({t1, t2}) + while not queue.isEmpty() do + local location + t1, t2, location = unpack(queue.pop()) + + local function toSublocation(key) + local keyAsString = tostring(key) + return (location and location .. "." .. keyAsString) or keyAsString + end + + for key, value1 in pairs(t1) do + local sublocation = toSublocation(key) + if t2[key] == nil then + return negate, string.format( + "Entry %s missing in second table (is %s in first)", + sublocation, tostring(value1)) + end + local value2 = t2[key] + if type(value1) == 'table' and type(value2) == 'table' then + queue.push({value1, value2, sublocation}) + else + local ok, message = check.areEq(value1, value2, tolerance, false) + if not ok then + message = 'At table location ' .. sublocation .. ': ' .. message + return negate, message + end + end + end + + for key, value2 in pairs(t2) do + local sublocation = toSublocation(key) + if t1[key] == nil then + return negate, string.format( + "Entry %s missing in first table (is %s in second)", + sublocation, tostring(value2)) + end + end end + return not negate, 'The tables are equal' end diff --git a/doc/tester.md b/doc/tester.md index 36d9a4c4..eab061ac 100644 --- a/doc/tester.md +++ b/doc/tester.md @@ -1,64 +1,85 @@ # Tester # -This class provides a generic unit testing framework. It is already +This class provides a generic unit testing framework. It is already being used in [nn](../index.md) package to verify the correctness of classes. The framework is generally used as follows. ```lua -mytest = {} +local mytest = torch.TestSuite() -tester = torch.Tester() +local tester = torch.Tester() -function mytest.TestA() - local a = 10 - local b = 10 - tester:asserteq(a,b,'a == b') - tester:assertne(a,b,'a ~= b') +function mytest.testA() + local a = torch.Tensor{1, 2, 3} + local b = torch.Tensor{1, 2, 4} + tester:eq(a, b, "a and b should be equal") end -function mytest.TestB() - local a = 10 - local b = 9 - tester:assertlt(a,b,'a < b') - tester:assertgt(a,b,'a > b') +function mytest.testB() + local a = {2, torch.Tensor{1, 2, 2}} + local b = {2, torch.Tensor{1, 2, 2.001}} + tester:eq(a, b, 0.01, "a and b should be approximately equal") +end + +function mytest.testC() + local function myfunc() + return "hello " .. world + end + tester:assertNoError(myfunc, "myfunc shouldn't give an error") end tester:add(mytest) tester:run() - ``` -Running this code will report 2 errors in 2 test functions. Generally it is -better to put a single test case in each test function unless several very related -test cases exist. The error report includes the message and line number of the error. +Running this code will report two test failures (and one test success). +Generally it is better to put a single test case in each test function unless +several very related test cases exist. +The error report includes the message and line number of the error. ``` - -Running 2 tests -** ==> Done - -Completed 2 tests with 2 errors - +Running 3 tests +1/3 testB ............................................................... [PASS] +2/3 testA ............................................................... [FAIL] +3/3 testC ............................................................... [FAIL] +Completed 3 asserts in 3 tests with 2 failures and 0 errors -------------------------------------------------------------------------------- -TestB -a < b - LT(<) violation val=10, condition=9 - ...y/usr.t7/local.master/share/lua/5.1/torch/Tester.lua:23: in function 'assertlt' - [string "function mytest.TestB()..."]:4: in function 'f' +testA +a and b should be equal +TensorEQ(==) violation: max diff=1, tolerance=0 +stack traceback: + ./test.lua:8: in function <./test.lua:5> -------------------------------------------------------------------------------- -TestA -a ~= b - NE(~=) violation val=10, condition=10 - ...y/usr.t7/local.master/share/lua/5.1/torch/Tester.lua:38: in function 'assertne' - [string "function mytest.TestA()..."]:5: in function 'f' +testC +myfunc shouldn't give an error +ERROR violation: err=./test.lua:19: attempt to concatenate global 'world' (a nil value) +stack traceback: + ./test.lua:21: in function <./test.lua:17> -------------------------------------------------------------------------------- - +torch/torch/Tester.lua:383: An error was found while running tests! +stack traceback: + [C]: in function 'assert' + torch/torch/Tester.lua:383: in function 'run' + ./test.lua:25: in main chunk ``` +Historically, Tester has supported a variety of equality checks +([asserteq](#torch.Tester.asserteq), +[assertalmosteq](#torch.Tester.assertalmosteq), +[assertTensorEq](#torch.Tester.assertTensorEq), +[assertTableEq](#torch.Tester.assertTableEq), and their negations). In general +however, you should just use [eq](#torch.Tester.eq) (or its negation +[ne](#torch.Tester.ne)). These functions do deep checking of many object types +including recursive tables and tensors, and support a +tolerance parameter for comparing numerical values (including tensors). + +Many of the tester functions accept both an optional `tolerance` parameter and a +`message` to display if the test case fails. For both convenience and backwards +compatibility, these arguments can be supplied in either order. ### torch.Tester() ### @@ -68,84 +89,275 @@ Returns a new instance of `torch.Tester` class. ### add(f, 'name') ### -Adds a new test function with name `name`. The test function is stored in `f`. -The function is supposed to run without any arguments and not return any values. +Add `f`, either a test function or a table of test functions, to the tester. - -### add(ftable) ### +If `f` is a function then names should be unique. There are a couple of special +values for `name`: if it is `_setUp` or `_tearDown`, then the function will be +called either *before* or *after* every test respectively, with the name of the +test passed as a parameter. + +If `f` is a table then `name` should be nil, and the names of the individual +tests in the table will be taken from the corresponding table key. It's +recommended you use [TestSuite](#torch.TestSuite.dok) for tables of tests. + +Returns the torch.Tester instance. + + +### run(testNames) ### + +Run tests that have been added by [add(f, 'name')](#torch.Tester.add). +While running it reports progress, and at the end gives a summary of all errors. -Recursively adds all function entries of the table `ftable` as tests. This table -can only have functions or nested tables of functions. +If a list of names `testNames` is passed, then all tests matching these names +(using `string.match`) will be run; otherwise all tests will be run. + +```lua +tester:run() -- runs all tests +tester:run("test1") -- runs the test named "test1" +tester:run({"test2", "test3"}) -- runs the tests named "test2" and "test3" +``` + + +### disable(testNames) ### + +Prevent the given tests from running, where `testNames` can be a single string +or list of strings. More precisely, when [run](#torch.Tester.run) +is invoked, it will skip these tests, while still printing out an indication of +skipped tests. This is useful for temporarily disabling tests without +commenting out the code (for example, if they depend on upstream code that is +currently broken), and explicitly flagging them as skipped. + +Returns the torch.Tester instance. + +```lua +local tester = torch.Tester() +local tests = torch.TestSuite() + +function tests.brokenTest() + -- ... +end + +tester:add(tests):disable('brokenTest'):run() +``` + +``` +Running 1 test +1/1 brokenTest .......................................................... [SKIP] +Completed 0 asserts in 1 test with 0 failures and 0 errors and 1 disabled +``` ### assert(condition [, message]) ### -Saves an error if condition is not true with the optional message. +Check that `condition` is true (using the optional `message` if the test +fails). +Returns whether the test passed. + + +### assertGeneralEq(got, expected [, tolerance] [, message]) ### + +General equality check between numbers, tables, strings, `torch.Tensor` +objects, `torch.Storage` objects, etc. + +Check that `got` and `expected` have the same contents, where tables are +compared recursively, tensors and storages are compared elementwise, and numbers +are compared within `tolerance` (default value `0`). Other types are compared by +strict equality. The optional `message` is used if the test fails. +Returns whether the test passed. + + +### eq(got, expected [, tolerance] [, message]) ### + +Convenience function; does the same as +[assertGeneralEq](#torch.Tester.assertGeneralEq). + + +### assertGeneralNe(got, unexpected [, tolerance] [, message]) ### + +General inequality check between numbers, tables, strings, `torch.Tensor` +objects, `torch.Storage` objects, etc. + +Check that `got` and `unexpected` have different contents, where tables are +compared recursively, tensors and storages are compared elementwise, and numbers +are compared within `tolerance` (default value `0`). Other types are compared by +strict equality. The optional `message` is used if the test fails. +Returns whether the test passed. + + +### ne(got, unexpected [, tolerance] [, message]) ### + +Convenience function; does the same as +[assertGeneralNe](#torch.Tester.assertGeneralNe). -### assertlt(val, condition [, message]) ### +### assertlt(a, b [, message]) ### -Saves an error if `val < condition` is not true with the optional message. +Check that `a < b` (using the optional `message` if the test fails), +where `a` and `b` are numbers. +Returns whether the test passed. -### assertgt(val, condition [, message]) ### +### assertgt(a, b [, message]) ### -Saves an error if `val > condition` is not true with the optional message. +Check that `a > b` (using the optional `message` if the test fails), +where `a` and `b` are numbers. +Returns whether the test passed. -### assertle(val, condition [, message]) ### +### assertle(a, b [, message]) ### -Saves an error if `val <= condition` is not true with the optional message. +Check that `a <= b` (using the optional `message` if the test fails), +where `a` and `b` are numbers. +Returns whether the test passed. -### assertge(val, condition [, message]) ### +### assertge(a, b [, message]) ### -Saves an error if `val >= condition` is not true with the optional message. +Check that `a >= b` (using the optional `message` if the test fails), +where `a` and `b` are numbers. +Returns whether the test passed. -### asserteq(val, condition [, message]) ### +### asserteq(a, b [, message]) ### -Saves an error if `val == condition` is not true with the optional message. +Check that `a == b` (using the optional `message` if the test fails). +Note that this uses the generic lua equality check, so objects such as tensors +that have the same content but are distinct objects will fail this test; +consider using [assertGeneralEq()](#torch.Tester.assertGeneralEq) instead. +Returns whether the test passed. -### assertne(val, condition [, message]) ### +### assertne(a, b [, message]) ### -Saves an error if `val ~= condition` is not true with the optional message. +Check that `a ~= b` (using the optional `message` if the test fails). +Note that this uses the generic lua inequality check, so objects such as tensors +that have the same content but are distinct objects will pass this test; +consider using [assertGeneralNe()](#torch.Tester.assertGeneralNe) instead. +Returns whether the test passed. + + +### assertalmosteq(a, b [, tolerance] [, message]) ### + +Check that `|a - b| <= tolerance` (using the optional `message` if the +test fails), where `a` and `b` are numbers, and `tolerance` is an optional +number (default `1e-16`). +Returns whether the test passed. -### assertTensorEq(ta, tb, condition [, message]) ### +### assertTensorEq(ta, tb [, tolerance] [, message]) ### -Saves an error if `max(abs(ta-tb)) < condition` is not true with the optional message. +Check that `max(abs(ta - tb)) <= tolerance` (using the optional `message` +if the test fails), where `ta` and `tb` are tensors, and `tolerance` is an +optional number (default `1e-16`). Tensors that are different types or sizes +will cause this check to fail. +Returns whether the test passed. -### assertTensorNe(ta, tb, condition [, message]) ### +### assertTensorNe(ta, tb [, tolerance] [, message]) ### -Saves an error if `max(abs(ta-tb)) >= condition` is not true with the optional message. +Check that `max(abs(ta - tb)) > tolerance` (using the optional `message` +if the test fails), where `ta` and `tb` are tensors, and `tolerance` is an +optional number (default `1e-16`). Tensors that are different types or sizes +will cause this check to pass. +Returns whether the test passed. -### assertTableEq(ta, tb, condition [, message]) ### +### assertTableEq(ta, tb [, tolerance] [, message]) ### -Saves an error if `max(abs(ta-tb)) < condition` is not true with the optional message. +Check that the two tables have the same contents, comparing them +recursively, where objects such as tensors are compared using their contents. +Numbers (such as those appearing in tensors) are considered equal if +their difference is at most the given tolerance. -### assertTableNe(ta, tb, condition [, message]) ### +### assertTableNe(ta, tb [, tolerance] [, message]) ### -Saves an error if `max(abs(ta-tb)) >= condition` is not true with the optional message. +Check that the two tables have distinct contents, comparing them +recursively, where objects such as tensors are compared using their contents. +Numbers (such as those appearing in tensors) are considered equal if +their difference is at most the given tolerance. ### assertError(f [, message]) ### -Saves an error if calling the function f() does not return an error, with the optional message. +Check that calling `f()` (via `pcall`) raises an error (using the +optional `message` if the test fails). +Returns whether the test passed. - -### run() ### + +### assertNoError(f [, message]) ### -Runs all the test functions that are stored using [add()](#torch.Tester.add) function. -While running it reports progress and at the end gives a summary of all errors. +Check that calling `f()` (via `pcall`) does not raise an error (using the +optional `message` if the test fails). +Returns whether the test passed. + +### assertErrorMsg(f, errmsg [, message]) ### +Check that calling `f()` (via `pcall`) raises an error with the specific error +message `errmsg` (using the optional `message` if the test fails). +Returns whether the test passed. + +### assertErrorPattern(f, errPattern [, message]) ### +Check that calling `f()` (via `pcall`) raises an error matching `errPattern` +(using the optional `message` if the test fails). +The matching is done using `string.find`; in particular substrings will match. +Returns whether the test passed. + +### assertErrorObj(f, errcomp [, message]) ### +Check that calling `f()` (via `pcall`) raises an error object `err` such that +calling `errcomp(err)` returns true (using the optional `message` if the test +fails). +Returns whether the test passed. + + +### setEarlyAbort(earlyAbort) ### + +If `earlyAbort == true` then the testing will stop on the first test failure. +By default this is off. + + +### setRethrowErrors(rethrowErrors) ### + +If `rethrowErrors == true` then lua errors encountered during the execution of +the tests will be rethrown, instead of being caught by the tester. +By default this is off. + + +### setSummaryOnly(summaryOnly) ### + +If `summaryOnly == true`, then only the pass / fail status of the tests will be +printed out, rather than full error messages. By default, this is off. + + + +# TestSuite # + +A TestSuite is used in conjunction with [Tester](#torch.Tester.dok). It is +created via `torch.TestSuite()`, and behaves like a plain lua table, +except that it also checks that duplicate tests are not created. +It is recommended that you always use a TestSuite instead of a plain table for +your tests. + +The following example code attempts to add a function with the same name +twice to a TestSuite (a surprisingly common mistake), which gives an error. + +```lua +> test = torch.TestSuite() +> +> function test.myTest() +> -- ... +> end +> +> -- ... +> +> function test.myTest() +> -- ... +> end +torch/TestSuite.lua:16: Test myTest is already defined. +``` diff --git a/init.lua b/init.lua index 978a4afe..cf8680a9 100644 --- a/init.lua +++ b/init.lua @@ -179,6 +179,7 @@ require('torch.File') require('torch.CmdLine') require('torch.FFI') require('torch.Tester') +require('torch.TestSuite') require('torch.test') function torch.totable(obj) diff --git a/rocks/torch-scm-1.rockspec b/rocks/torch-scm-1.rockspec index 22287264..2906ae47 100644 --- a/rocks/torch-scm-1.rockspec +++ b/rocks/torch-scm-1.rockspec @@ -16,7 +16,8 @@ description = { dependencies = { "lua >= 5.1", "paths >= 1.0", - "cwrap >= 1.0" + "cwrap >= 1.0", + "sys >= 1.0" } build = { diff --git a/test/test.lua b/test/test.lua index f86fc728..0640b2f4 100644 --- a/test/test.lua +++ b/test/test.lua @@ -1,7 +1,7 @@ --require 'torch' local mytester -local torchtest = {} +local torchtest = torch.TestSuite() local msize = 100 local precision @@ -725,22 +725,22 @@ function torchtest.addbmm() local res2 = torch.Tensor():resizeAs(res[1]):zero() res2:addbmm(b1,b2) - mytester:assertTensorEq(res2, res:sum(1), precision, 'addbmm result wrong') + mytester:assertTensorEq(res2, res:sum(1)[1], precision, 'addbmm result wrong') res2:addbmm(1,b1,b2) - mytester:assertTensorEq(res2, res:sum(1)*2, precision, 'addbmm result wrong') + mytester:assertTensorEq(res2, res:sum(1)[1]*2, precision, 'addbmm result wrong') res2:addbmm(1,res2,.5,b1,b2) - mytester:assertTensorEq(res2, res:sum(1)*2.5, precision, 'addbmm result wrong') + mytester:assertTensorEq(res2, res:sum(1)[1]*2.5, precision, 'addbmm result wrong') local res3 = torch.addbmm(1,res2,0,b1,b2) mytester:assertTensorEq(res3, res2, precision, 'addbmm result wrong') local res4 = torch.addbmm(1,res2,.5,b1,b2) - mytester:assertTensorEq(res4, res:sum(1)*3, precision, 'addbmm result wrong') + mytester:assertTensorEq(res4, res:sum(1)[1]*3, precision, 'addbmm result wrong') local res5 = torch.addbmm(0,res2,1,b1,b2) - mytester:assertTensorEq(res5, res:sum(1), precision, 'addbmm result wrong') + mytester:assertTensorEq(res5, res:sum(1)[1], precision, 'addbmm result wrong') local res6 = torch.addbmm(.1,res2,.5,b1,b2) mytester:assertTensorEq(res6, res2*.1 + res:sum(1)*.5, precision, 'addbmm result wrong') @@ -1510,8 +1510,10 @@ function torchtest.kthvalue() local mx, ix = torch.kthvalue(x, k) local mxx, ixx = torch.sort(x) - mytester:assertTensorEq(mxx:select(3, k), mx, 0, 'torch.kthvalue value') - mytester:assertTensorEq(ixx:select(3, k), ix, 0, 'torch.kthvalue index') + mytester:assertTensorEq(mxx:select(3, k), mx:select(3, 1), 0, + 'torch.kthvalue value') + mytester:assertTensorEq(ixx:select(3, k), ix:select(3, 1), 0, + 'torch.kthvalue index') end do -- test use of result tensors local k = math.random(1, msize) @@ -1519,15 +1521,19 @@ function torchtest.kthvalue() local ix = torch.LongTensor() torch.kthvalue(mx, ix, x, k) local mxx, ixx = torch.sort(x) - mytester:assertTensorEq(mxx:select(3, k), mx, 0, 'torch.kthvalue value') - mytester:assertTensorEq(ixx:select(3, k), ix, 0, 'torch.kthvalue index') + mytester:assertTensorEq(mxx:select(3, k), mx:select(3, 1), 0, + 'torch.kthvalue value') + mytester:assertTensorEq(ixx:select(3, k), ix:select(3, 1), 0, + 'torch.kthvalue index') end do -- test non-default dim local k = math.random(1, msize) local mx, ix = torch.kthvalue(x, k, 1) local mxx, ixx = torch.sort(x, 1) - mytester:assertTensorEq(mxx:select(1, k), mx, 0, 'torch.kthvalue value') - mytester:assertTensorEq(ixx:select(1, k), ix, 0, 'torch.kthvalue index') + mytester:assertTensorEq(mxx:select(1, k), mx[1], 0, + 'torch.kthvalue value') + mytester:assertTensorEq(ixx:select(1, k), ix[1], 0, + 'torch.kthvalue index') end do -- non-contiguous local y = x:narrow(2, 1, 1) @@ -1557,8 +1563,10 @@ function torchtest.median() local mxx, ixx = torch.sort(x) local ind = math.floor((msize+1)/2) - mytester:assertTensorEq(mxx:select(2, ind), mx, 0, 'torch.median value') - mytester:assertTensorEq(ixx:select(2, ind), ix, 0, 'torch.median index') + mytester:assertTensorEq(mxx:select(2, ind), mx:select(2, 1), 0, + 'torch.median value') + mytester:assertTensorEq(ixx:select(2, ind), ix:select(2, 1), 0, + 'torch.median index') -- Test use of result tensor local mr = torch.Tensor() @@ -1570,8 +1578,10 @@ function torchtest.median() -- Test non-default dim mx, ix = torch.median(x, 1) mxx, ixx = torch.sort(x, 1) - mytester:assertTensorEq(mxx:select(1, ind), mx, 0,'torch.median value') - mytester:assertTensorEq(ixx:select(1, ind), ix, 0,'torch.median index') + mytester:assertTensorEq(mxx:select(1, ind), mx[1], 0, + 'torch.median value') + mytester:assertTensorEq(ixx:select(1, ind), ix[1], 0, + 'torch.median index') -- input unchanged mytester:assertTensorEq(x, x0, 0, 'torch.median modified input') @@ -1658,7 +1668,7 @@ function torchtest.catArray() mytester:assertTensorEq(mx, mxx, 0, 'torch.cat value') end end -function torchtest.sin() +function torchtest.sin_2() local x = torch.rand(msize,msize,msize) local mx = torch.sin(x) local mxx = torch.Tensor() @@ -2205,30 +2215,6 @@ function torchtest.conv3_conv2_eq() mytester:assertlt(maxdiff(o3,o32),precision,'torch.conv3_conv2_eq') end -function torchtest.fxcorr3_fxcorr2_eq() - local ix = math.floor(torch.uniform(20,40)) - local iy = math.floor(torch.uniform(20,40)) - local iz = math.floor(torch.uniform(20,40)) - local kx = math.floor(torch.uniform(5,10)) - local ky = math.floor(torch.uniform(5,10)) - local kz = math.floor(torch.uniform(5,10)) - - local x = torch.rand(ix,iy,iz) - local k = torch.rand(kx,ky,kz) - - local o3 = torch.xcorr3(x,k,'F') - - local o32 = torch.zeros(o3:size()) - - for i=1,x:size(1) do - for j=1,k:size(1) do - o32[i+j-1]:add(torch.xcorr2(x[i],k[k:size(1)-j + 1],'F')) - end - end - - mytester:assertlt(maxdiff(o3,o32),precision,'torch.conv3_conv2_eq') -end - function torchtest.fconv3_fconv2_eq() local ix = math.floor(torch.uniform(20,40)) local iy = math.floor(torch.uniform(20,40)) @@ -2269,27 +2255,6 @@ function torchtest.logical() mytester:asserteq(x:nElement(),all:double():sum() , 'torch.logical') end -function torchtest.TestAsserts() - mytester:assertError(function() error('hello') end, 'assertError: Error not caught') - mytester:assertErrorPattern(function() error('hello') end, '.*ll.*', 'assertError: ".*ll.*" Error not caught') - - local x = torch.rand(100,100)*2-1; - local xx = x:clone(); - mytester:assertTensorEq(x, xx, 1e-16, 'assertTensorEq: not deemed equal') - mytester:assertTensorNe(x, xx+1, 1e-16, 'assertTensorNe: not deemed different') - mytester:assertalmosteq(0, 1e-250, 1e-16, 'assertalmosteq: not deemed different') -end - -function torchtest.BugInAssertTableEq() - local t = {1,2,3} - local tt = {1,2,3} - mytester:assertTableEq(t, tt, 'assertTableEq: not deemed equal') - mytester:assertTableNe(t, {3,2,1}, 'assertTableNe: not deemed different') - mytester:assertTableEq({1,2,{4,5}}, {1,2,{4,5}}, 'assertTableEq: fails on recursive lists') - mytester:assertTableNe(t, {1,2}, 'assertTableNe: different size not deemed different') - mytester:assertTableNe(t, {1,2,3,4}, 'assertTableNe: different size not deemed different') -end - function torchtest.RNGState() local state = torch.getRNGState() local stateCloned = state:clone() @@ -2750,14 +2715,14 @@ function torchtest.classInModule() -- Need a global for this module _mymodule123 = {} local x = torch.class('_mymodule123.myclass') - mytester:assert(x, 'Could not create class in module') + mytester:assert(x ~= nil, 'Could not create class in module') -- Remove the global _G['_mymodule123'] = nil end function torchtest.classNoModule() local x = torch.class('_myclass123') - mytester:assert(x, 'Could not create class in module') + mytester:assert(x ~= nil, 'Could not create class in module') end function torchtest.type() @@ -3003,7 +2968,7 @@ function torchtest.split() mytester:assertTensorEq(tensor:narrow(dim, start, targetSize[i][dim]), split, 0.000001, 'Result content error in split '..i) start = start + targetSize[i][dim] end - mytester:asserteq(#splits,#result, 0, 'Non-consistent output size from split') + mytester:asserteq(#splits, #result, 'Non-consistent output size from split') for i, split in ipairs(splits) do mytester:assertTensorEq(split,result[i], 0, 'Non-consistent outputs from split') end @@ -3136,13 +3101,13 @@ function torchtest.nonzero() table.insert(dst, i) end end - mytester:assertTensorEq(dst1, torch.LongTensor(dst), 0.0, + mytester:assertTensorEq(dst1:select(2, 1), torch.LongTensor(dst), 0.0, "nonzero error") - mytester:assertTensorEq(dst2, torch.LongTensor(dst), 0.0, + mytester:assertTensorEq(dst2:select(2, 1), torch.LongTensor(dst), 0.0, "nonzero error") - --mytester:assertTensorEq(dst3, torch.LongTensor(dst), 0.0, - -- "nonzero error") - mytester:assertTensorEq(dst4, torch.LongTensor(dst), 0.0, + --mytester:assertTensorEq(dst3:select(2, 1), torch.LongTensor(dst), + -- 0.0, "nonzero error") + mytester:assertTensorEq(dst4:select(2, 1), torch.LongTensor(dst), 0.0, "nonzero error") elseif shape:size() == 2 then -- This test will allow through some false positives. It only checks diff --git a/test/test_Tester.lua b/test/test_Tester.lua new file mode 100644 index 00000000..a2833608 --- /dev/null +++ b/test/test_Tester.lua @@ -0,0 +1,626 @@ +require 'torch' + +local tester = torch.Tester() + +local MESSAGE = "a really useful informative error message" + +local subtester = torch.Tester() +-- The message only interests us in case of failure +subtester._success = function(self) return true, MESSAGE end +subtester._failure = function(self, message) return false, message end + +local tests = torch.TestSuite() + +local test_name_passed_to_setUp +local calls_to_setUp = 0 +local calls_to_tearDown = 0 + +local originalIoWrite = io.write +local function disableIoWrite() + io.write = function() end +end +local function enableIoWrite() + io.write = originalIoWrite +end + +local function meta_assert_success(success, message) + tester:assert(success == true, "assert wasn't successful") + tester:assert(string.find(message, MESSAGE) ~= nil, "message doesn't match") +end +local function meta_assert_failure(success, message) + tester:assert(success == false, "assert didn't fail") + tester:assert(string.find(message, MESSAGE) ~= nil, "message doesn't match") +end + +function tests.really_test_assert() + assert((subtester:assert(true, MESSAGE)), + "subtester:assert doesn't actually work!") + assert(not (subtester:assert(false, MESSAGE)), + "subtester:assert doesn't actually work!") +end + +function tests.setEarlyAbort() + disableIoWrite() + + for _, earlyAbort in ipairs{false, true} do + local myTester = torch.Tester() + + local invokedCount = 0 + local myTests = {} + function myTests.t1() + invokedCount = invokedCount + 1 + myTester:assert(false) + end + myTests.t2 = myTests.t1 + + myTester:setEarlyAbort(earlyAbort) + myTester:add(myTests) + pcall(myTester.run, myTester) + + tester:assert(invokedCount == (earlyAbort and 1 or 2), + "wrong number of tests invoked for use with earlyAbort") + end + + enableIoWrite() +end + +function tests.setRethrowErrors() + disableIoWrite() + + local myTester = torch.Tester() + myTester:setRethrowErrors(true) + myTester:add(function() error("a throw") end) + + tester:assertErrorPattern(function() myTester:run() end, + "a throw", + "error should be rethrown") + + enableIoWrite() +end + +function tests.disable() + disableIoWrite() + + for disableCount = 1, 2 do + local myTester = torch.Tester() + local tests = {} + local test1Invoked = false + local test2Invoked = false + function tests.test1() + test1Invoked = true + end + function tests.test2() + test2Invoked = true + end + myTester:add(tests) + + if disableCount == 1 then + myTester:disable('test1'):run() + tester:assert((not test1Invoked) and test2Invoked, + "disabled test shouldn't have been invoked") + else + myTester:disable({'test1', 'test2'}):run() + tester:assert((not test1Invoked) and (not test2Invoked), + "disabled tests shouldn't have been invoked") + end + end + + enableIoWrite() +end + +function tests.assert() + meta_assert_success(subtester:assert(true, MESSAGE)) + meta_assert_failure(subtester:assert(false, MESSAGE)) +end + +local function testEqNe(eqExpected, ...) + if eqExpected then + meta_assert_success(subtester:eq(...)) + meta_assert_failure(subtester:ne(...)) + else + meta_assert_failure(subtester:eq(...)) + meta_assert_success(subtester:ne(...)) + end +end + +--[[ Test :assertGeneralEq and :assertGeneralNe (also known as :eq and :ne). + +Note that in-depth testing of testing of many specific types of data (such as +Tensor) is covered below, when we test specific functions (such as +:assertTensorEq). This just does a general check, as well as testing of testing +of mixed datatypes. +]] +function tests.assertGeneral() + local one = torch.Tensor{1} + + testEqNe(true, one, one, MESSAGE) + testEqNe(false, one, 1, MESSAGE) + testEqNe(true, "hi", "hi", MESSAGE) + testEqNe(true, {one, 1}, {one, 1}, MESSAGE) + testEqNe(true, {{{one}}}, {{{one}}}, MESSAGE) + testEqNe(false, {{{one}}}, {{one}}, MESSAGE) + testEqNe(true, torch.Storage{1}, torch.Storage{1}, MESSAGE) + testEqNe(false, torch.FloatStorage{1}, torch.LongStorage{1}, MESSAGE) + testEqNe(false, torch.Storage{1}, torch.Storage{1, 2}, MESSAGE) + testEqNe(false, "one", 1, MESSAGE) + testEqNe(false, {one}, {one + torch.Tensor{1e-10}}, MESSAGE) + testEqNe(true, {one}, {one + torch.Tensor{1e-10}}, 1e-9, MESSAGE) +end + +function tests.assertlt() + meta_assert_success(subtester:assertlt(1, 2, MESSAGE)) + meta_assert_failure(subtester:assertlt(2, 1, MESSAGE)) + meta_assert_failure(subtester:assertlt(1, 1, MESSAGE)) +end + +function tests.assertgt() + meta_assert_success(subtester:assertgt(2, 1, MESSAGE)) + meta_assert_failure(subtester:assertgt(1, 2, MESSAGE)) + meta_assert_failure(subtester:assertgt(1, 1, MESSAGE)) +end + +function tests.assertle() + meta_assert_success(subtester:assertle(1, 2, MESSAGE)) + meta_assert_failure(subtester:assertle(2, 1, MESSAGE)) + meta_assert_success(subtester:assertle(1, 1, MESSAGE)) +end + +function tests.assertge() + meta_assert_success(subtester:assertge(2, 1, MESSAGE)) + meta_assert_failure(subtester:assertge(1, 2, MESSAGE)) + meta_assert_success(subtester:assertge(1, 1, MESSAGE)) +end + +function tests.asserteq() + meta_assert_success(subtester:asserteq(1, 1, MESSAGE)) + meta_assert_failure(subtester:asserteq(1, 2, MESSAGE)) +end + +function tests.assertalmosteq() + meta_assert_success(subtester:assertalmosteq(1, 1, MESSAGE)) + meta_assert_success(subtester:assertalmosteq(1, 1 + 1e-17, MESSAGE)) + meta_assert_success(subtester:assertalmosteq(1, 2, 2, MESSAGE)) + meta_assert_failure(subtester:assertalmosteq(1, 2, MESSAGE)) + meta_assert_failure(subtester:assertalmosteq(1, 3, 1, MESSAGE)) +end + +function tests.assertne() + meta_assert_success(subtester:assertne(1, 2, MESSAGE)) + meta_assert_failure(subtester:assertne(1, 1, MESSAGE)) +end + +-- The `alsoTestEq` flag is provided to test :eq in addition to :assertTensorEq. +-- The behaviour of the two isn't always the same due to handling of tensors of +-- different dimensions but the same number of elements. +local function testTensorEqNe(eqExpected, alsoTestEq, ...) + if eqExpected then + meta_assert_success(subtester:assertTensorEq(...)) + meta_assert_failure(subtester:assertTensorNe(...)) + if alsoTestEq then + meta_assert_success(subtester:eq(...)) + meta_assert_failure(subtester:ne(...)) + end + else + meta_assert_failure(subtester:assertTensorEq(...)) + meta_assert_success(subtester:assertTensorNe(...)) + if alsoTestEq then + meta_assert_failure(subtester:eq(...)) + meta_assert_success(subtester:ne(...)) + end + end +end + +function tests.assertTensor_types() + local allTypes = { + torch.ByteTensor, + torch.CharTensor, + torch.ShortTensor, + torch.IntTensor, + torch.LongTensor, + torch.FloatTensor, + torch.DoubleTensor, + } + for _, tensor1 in ipairs(allTypes) do + for _, tensor2 in ipairs(allTypes) do + local t1 = tensor1():ones(10) + local t2 = tensor2():ones(10) + testTensorEqNe(tensor1 == tensor2, true, t1, t2, 1e-6, MESSAGE) + end + end + + testTensorEqNe(false, true, torch.FloatTensor(), torch.LongTensor(), MESSAGE) +end + +function tests.assertTensor_sizes() + local t = torch.Tensor() -- no dimensions + local t2 = torch.ones(2) + local t3 = torch.ones(3) + local t12 = torch.ones(1, 2) + assert(subtester._assertTensorEqIgnoresDims == true) -- default state + testTensorEqNe(false, false, t, t2, 1e-6, MESSAGE) + testTensorEqNe(false, false, t, t3, 1e-6, MESSAGE) + testTensorEqNe(false, false, t, t12, 1e-6, MESSAGE) + testTensorEqNe(false, false, t2, t3, 1e-6, MESSAGE) + testTensorEqNe(true, false, t2, t12, 1e-6, MESSAGE) + testTensorEqNe(false, false, t3, t12, 1e-6, MESSAGE) + subtester._assertTensorEqIgnoresDims = false + testTensorEqNe(false, true, t, t2, 1e-6, MESSAGE) + testTensorEqNe(false, true, t, t3, 1e-6, MESSAGE) + testTensorEqNe(false, true, t, t12, 1e-6, MESSAGE) + testTensorEqNe(false, true, t2, t3, 1e-6, MESSAGE) + testTensorEqNe(false, true, t2, t12, 1e-6, MESSAGE) + testTensorEqNe(false, true, t3, t12, 1e-6, MESSAGE) + subtester._assertTensorEqIgnoresDims = true -- reset back +end + +function tests.assertTensor_epsilon() + local t1 = torch.rand(100, 100) + local t2 = torch.rand(100, 100) * 1e-5 + local t3 = t1 + t2 + testTensorEqNe(true, true, t1, t3, 1e-4, MESSAGE) + testTensorEqNe(false, true, t1, t3, 1e-6, MESSAGE) +end + +function tests.assertTensor_arg() + local one = torch.Tensor{1} + + tester:assertErrorPattern( + function() subtester:assertTensorEq(one, 2) end, + "Second argument should be a Tensor") + + -- Test that assertTensorEq support message and tolerance in either ordering + tester:assertNoError( + function() subtester:assertTensorEq(one, one, 0.1, MESSAGE) end) + tester:assertNoError( + function() subtester:assertTensorEq(one, one, MESSAGE, 0.1) end) +end + +function tests.assertTensor() + local t1 = torch.randn(100, 100) + local t2 = t1:clone() + local t3 = torch.randn(100, 100) + testTensorEqNe(true, true, t1, t2, 1e-6, MESSAGE) + testTensorEqNe(false, true, t1, t3, 1e-6, MESSAGE) + testTensorEqNe(true, true, torch.Tensor(), torch.Tensor(), MESSAGE) +end + +-- Check that calling assertTensorEq with two tensors with the same content but +-- different dimensions gives a warning. +function tests.assertTensorDimWarning() + local myTester = torch.Tester() + myTester:add( + function() + myTester:assertTensorEq(torch.Tensor{{1}}, torch.Tensor{1}) + end) + + local warningGiven = false + io.write = function(s) + if string.match(s, 'but different dimensions') then + warningGiven = true + end + end + + myTester:run() + enableIoWrite() + + tester:assert(warningGiven, + "Calling :assertTensorEq({{1}}, {1}) should give a warning") +end + +local function testTableEqNe(eqExpected, ...) + if eqExpected then + meta_assert_success(subtester:assertTableEq(...)) + meta_assert_failure(subtester:assertTableNe(...)) + meta_assert_success(subtester:eq(...)) + meta_assert_failure(subtester:ne(...)) + else + meta_assert_failure(subtester:assertTableEq(...)) + meta_assert_success(subtester:assertTableNe(...)) + meta_assert_failure(subtester:eq(...)) + meta_assert_success(subtester:ne(...)) + end +end + +function tests.assertTable() + testTableEqNe(true, {1, 2, 3}, {1, 2, 3}, MESSAGE) + testTableEqNe(false, {1, 2, 3}, {3, 2, 1}, MESSAGE) + testTableEqNe(true, {1, 2, {4, 5}}, {1, 2, {4, 5}}, MESSAGE) + testTableEqNe(false, {1, 2, 3}, {1,2}, MESSAGE) + testTableEqNe(false, {1, 2, 3}, {1, 2, 3, 4}, MESSAGE) + testTableEqNe(true, {{1}}, {{1}}, MESSAGE) + testTableEqNe(false, {{1}}, {{{1}}}, MESSAGE) + testTableEqNe(true, {false}, {false}, MESSAGE) + testTableEqNe(false, {true}, {false}, MESSAGE) + testTableEqNe(false, {false}, {true}, MESSAGE) + + local tensor = torch.rand(100, 100) + local t1 = {1, "a", key = "value", tensor = tensor, subtable = {"nested"}} + local t2 = {1, "a", key = "value", tensor = tensor, subtable = {"nested"}} + testTableEqNe(true, t1, t2, MESSAGE) + for k, v in pairs(t1) do + local x = "something else" + t2[k] = nil + t2[x] = v + testTableEqNe(false, t1, t2, MESSAGE) + t2[x] = nil + t2[k] = x + testTableEqNe(false, t1, t2, MESSAGE) + t2[k] = v + testTableEqNe(true, t1, t2, MESSAGE) + end +end + +local function good_fn() end +local function bad_fn() error("muahaha!") end + +function tests.assertError() + meta_assert_success(subtester:assertError(bad_fn, MESSAGE)) + meta_assert_failure(subtester:assertError(good_fn, MESSAGE)) +end + +function tests.assertNoError() + meta_assert_success(subtester:assertNoError(good_fn, MESSAGE)) + meta_assert_failure(subtester:assertNoError(bad_fn, MESSAGE)) +end + +function tests.assertErrorPattern() + meta_assert_success(subtester:assertErrorPattern(bad_fn, "haha", MESSAGE)) + meta_assert_failure(subtester:assertErrorPattern(bad_fn, "hehe", MESSAGE)) +end + +function tests.testSuite_duplicateTests() + local function createDuplicateTests() + local tests = torch.TestSuite() + function tests.testThis() end + function tests.testThis() end + end + tester:assertErrorPattern(createDuplicateTests, + "Test testThis is already defined.") +end + +--[[ Returns a Tester with `numSuccess` success cases, `numFailure` failure + cases, and with an error if `hasError` is true. + Success and fail tests are evaluated with tester:eq +]] +local function genDummyTest(numSuccess, numFailure, hasError) + hasError = hasError or false + + local dummyTester = torch.Tester() + local dummyTests = torch.TestSuite() + + if numSuccess > 0 then + function dummyTests.testDummySuccess() + for i = 1, numSuccess do + dummyTester:eq({1}, {1}, '', 0) + end + end + end + + if numFailure > 0 then + function dummyTests.testDummyFailure() + for i = 1, numFailure do + dummyTester:eq({1}, {2}, '', 0) + end + end + end + + if hasError then + function dummyTests.testDummyError() + error('dummy error') + end + end + + return dummyTester:add(dummyTests) +end + +function tests.runStatusAndAssertCounts() + local emptyTest = genDummyTest(0, 0, false) + local sucTest = genDummyTest(1, 0, false) + local multSucTest = genDummyTest(4, 0, false) + local failTest = genDummyTest(0, 1, false) + local errTest = genDummyTest(0, 0, true) + local errFailTest = genDummyTest(0, 1, true) + local errSucTest = genDummyTest(1, 0, true) + local failSucTest = genDummyTest(1, 1, false) + local failSucErrTest = genDummyTest(1, 1, true) + + disableIoWrite() + + local success, msg = pcall(emptyTest.run, emptyTest) + tester:asserteq(success, true, "pcall should succeed for empty tests") + + local success, msg = pcall(sucTest.run, sucTest) + tester:asserteq(success, true, "pcall should succeed for 1 successful test") + + local success, msg = pcall(multSucTest.run, multSucTest) + tester:asserteq(success, true, + "pcall should succeed for 2+ successful tests") + + local success, msg = pcall(failTest.run, failTest) + tester:asserteq(success, false, "pcall should fail for tests with failure") + + local success, msg = pcall(errTest.run, errTest) + tester:asserteq(success, false, "pcall should fail for tests with error") + + local success, msg = pcall(errFailTest.run, errFailTest) + tester:asserteq(success, false, "pcall should fail for error+fail tests") + + local success, msg = pcall(errSucTest.run, errSucTest) + tester:asserteq(success, false, "pcall should fail for error+success tests") + + local success, msg = pcall(failSucTest.run, failSucTest) + tester:asserteq(success, false, "pcall should fail for fail+success tests") + + local success, msg = pcall(failSucErrTest.run, failSucErrTest) + tester:asserteq(success, false, + "pcall should fail for fail+success+err test") + + enableIoWrite() + + tester:asserteq(emptyTest.countasserts, 0, + "emptyTest should have 0 asserts") + tester:asserteq(sucTest.countasserts, 1, "sucTest should have 1 assert") + tester:asserteq(multSucTest.countasserts, 4, + "multSucTest should have 4 asserts") + tester:asserteq(failTest.countasserts, 1, "failTest should have 1 assert") + tester:asserteq(errTest.countasserts, 0, "errTest should have 0 asserts") + tester:asserteq(errFailTest.countasserts, 1, + "errFailTest should have 1 assert") + tester:asserteq(errSucTest.countasserts, 1, + "errSucTest should have 0 asserts") + tester:asserteq(failSucTest.countasserts, 2, + "failSucTest should have 2 asserts") +end + +function tests.checkNestedTestsForbidden() + disableIoWrite() + + local myTester = torch.Tester() + local myTests = {{function() end}} + tester:assertErrorPattern(function() myTester:add(myTests) end, + "Nested sets", + "tester should forbid adding nested test sets") + + enableIoWrite() +end + +function tests.checkWarningOnAssertObject() + -- This test checks that calling assert with an object generates a warning + local myTester = torch.Tester() + local myTests = {} + function myTests.assertAbuse() + myTester:assert({}) + end + myTester:add(myTests) + + local warningGiven = false + io.write = function(s) + if string.match(s, 'should only be used for boolean') then + warningGiven = true + end + end + + myTester:run() + enableIoWrite() + + tester:assert(warningGiven, "Should warn on calling :assert(object)") +end + +function tests.checkWarningOnAssertNeObject() + -- This test checks that calling assertne with two objects generates warning + local myTester = torch.Tester() + local myTests = {} + function myTests.assertAbuse() + myTester:assertne({}, {}) + end + myTester:add(myTests) + + local warningGiven = false + io.write = function(s) + if string.match(s, 'assertne should only be used to compare basic') then + warningGiven = true + end + end + + myTester:run() + enableIoWrite() + + tester:assert(warningGiven, "Should warn on calling :assertne(obj, obj)") +end + +function tests.checkWarningOnExtraAssertArguments() + -- This test checks that calling assert with extra args gives a lua error + local myTester = torch.Tester() + local myTests = {} + function myTests.assertAbuse() + myTester:assert(true, "some message", "extra argument") + end + myTester:add(myTests) + + local errorGiven = false + io.write = function(s) + if string.match(s, 'Unexpected arguments') then + errorGiven = true + end + end + tester:assertError(function() myTester:run() end) + enableIoWrite() + + tester:assert(errorGiven, ":assert should fail on extra arguments") +end + +function tests.checkWarningOnUsingTable() + -- Checks that if we don't use a TestSuite then gives a warning + local myTester = torch.Tester() + local myTests = {} + myTester:add(myTests) + + local errorGiven = false + io.write = function(s) + if string.match(s, 'use TestSuite rather than plain lua table') then + errorGiven = true + end + end + myTester:run() + + enableIoWrite() + tester:assert(errorGiven, "Using a plain lua table for testsuite should warn") +end + +function tests.checkMaxAllowedSetUpAndTearDown() + -- Checks can have at most 1 set-up and at most 1 tear-down function + local function f() end + local myTester = torch.Tester() + + for _, name in ipairs({'_setUp', '_tearDown'}) do + tester:assertNoError(function() myTester:add(f, name) end, + "Adding 1 set-up / tear-down should be fine") + tester:assertErrorPattern(function() myTester:add(f, name) end, + "Only one", + "Adding second set-up / tear-down should fail") + end +end + +function tests.test_setUp() + tester:asserteq(test_name_passed_to_setUp, 'test_setUp') + for key, value in pairs(tester.tests) do + tester:assertne(key, '_setUp') + end +end + +function tests.test_tearDown() + for key, value in pairs(tester.tests) do + tester:assertne(key, '_tearDown') + end +end + +function tests._setUp(name) + test_name_passed_to_setUp = name + calls_to_setUp = calls_to_setUp + 1 +end + +function tests._tearDown(name) + calls_to_tearDown = calls_to_tearDown + 1 +end + +tester:add(tests):run() + +-- Additional tests to check that _setUp and _tearDown were called. +local test_count = 0 +for _ in pairs(tester.tests) do + test_count = test_count + 1 +end +local postTests = torch.TestSuite() +local postTester = torch.Tester() + +function postTests.test_setUp(tester) + postTester:asserteq(calls_to_setUp, test_count, + "Expected " .. test_count .. " calls to _setUp") +end + +function postTests.test_tearDown() + postTester:asserteq(calls_to_tearDown, test_count, + "Expected " .. test_count .. " calls to _tearDown") +end + +postTester:add(postTests):run() diff --git a/test/test_qr.lua b/test/test_qr.lua index c00c6046..c850c3fe 100644 --- a/test/test_qr.lua +++ b/test/test_qr.lua @@ -2,7 +2,7 @@ -- torch.qr(), torch.geqrf() and torch.orgqr(). local torch = require 'torch' local tester = torch.Tester() -local tests = {} +local tests = torch.TestSuite() -- torch.qr() with result tensors given. local function qrInPlace(tensorFunc) diff --git a/test/test_sharedmem.lua b/test/test_sharedmem.lua index 9f594fea..14cdeaf0 100644 --- a/test/test_sharedmem.lua +++ b/test/test_sharedmem.lua @@ -1,7 +1,7 @@ require 'torch' local tester = torch.Tester() -local tests = {} +local tests = torch.TestSuite() local function createSharedMemStorage(name, size, storageType) local storageType = storageType or 'FloatStorage' diff --git a/test/test_writeObject.lua b/test/test_writeObject.lua index 8bccf100..ccf7eba2 100644 --- a/test/test_writeObject.lua +++ b/test/test_writeObject.lua @@ -2,7 +2,7 @@ require 'torch' local myTester = torch.Tester() -local tests = {} +local tests = torch.TestSuite() -- checks that an object can be written and unwritten @@ -91,7 +91,7 @@ function tests.test_error_msg() end local ok, msg = pcall(torch.save, 'saved.t7', evil_func) myTester:assert(not ok) - myTester:assert(msg:find('at <%?>%.outer%.theinner%.baz%.torch')) + myTester:assert(msg:find('at <%?>%.outer%.theinner%.baz%.torch') ~= nil) end function tests.test_warning_msg()