From dfe37ba20c1650c26135d0fed2d50fd3e2b46ab9 Mon Sep 17 00:00:00 2001 From: mmoriarity-stripe <51032382+mmoriarity-stripe@users.noreply.github.com> Date: Fri, 6 Dec 2019 10:55:40 -0700 Subject: [PATCH] Add `assert.fails` to check if a function fails. (#66) Example: ```python def might_fail(input): if input < 10: fail("input is less than 10") else: return input def test_might_fail(t): t.assert.fails(might_fail, 3) # passes t.assert.fails(might_fail, 14) # assertion failure: function might_fail should have failed ``` --- CONTRIBUTORS | 1 + internal/go/skycfg/assert.go | 41 ++++++++++++++++++++++++++ internal/go/skycfg/assert_test.go | 48 ++++++++++++++++++++++++++++++- internal/go/skycfg/fail.go | 19 ++++++++++++ skycfg.go | 12 +------- 5 files changed, 109 insertions(+), 12 deletions(-) create mode 100644 internal/go/skycfg/fail.go diff --git a/CONTRIBUTORS b/CONTRIBUTORS index b743bff..31e5b62 100644 --- a/CONTRIBUTORS +++ b/CONTRIBUTORS @@ -10,5 +10,6 @@ Benjamin Yolken Dmitry Ilyevsky # GM Cruise LLC Isaac Diamond John Millikin +Matt Moriarity # Please alphabetize new entries. diff --git a/internal/go/skycfg/assert.go b/internal/go/skycfg/assert.go index 9286143..04d40e0 100644 --- a/internal/go/skycfg/assert.go +++ b/internal/go/skycfg/assert.go @@ -21,6 +21,7 @@ import ( "sort" "go.starlark.net/starlark" + "go.starlark.net/starlarkstruct" "go.starlark.net/syntax" ) @@ -39,6 +40,9 @@ func AssertModule() *TestContext { for op, str := range tokenToString { ctx.Attrs[str] = starlark.NewBuiltin(fmt.Sprintf("assert.%s", str), ctx.AssertBinaryImpl(op)) } + + ctx.Attrs["fails"] = starlark.NewBuiltin("assert.fails", ctx.AssertFails) + return ctx } @@ -47,6 +51,7 @@ type assertionError struct { op *syntax.Token val1 starlark.Value val2 starlark.Value + msg string callStack starlark.CallStack } @@ -55,6 +60,11 @@ func (err assertionError) Error() string { position := callStack.At(0).Pos.String() backtrace := callStack.String() + // use custom message if provided + if err.msg != "" { + return fmt.Sprintf("[%s] assertion failed: %s\n%s", position, err.msg, backtrace) + } + // straight boolean assertions like assert.true(false) if err.op == nil { return fmt.Sprintf("[%s] assertion failed\n%s", position, backtrace) @@ -153,6 +163,37 @@ func (t *TestContext) AssertBinaryImpl(op syntax.Token) func(thread *starlark.Th } } +func (t *TestContext) AssertFails(thread *starlark.Thread, fn *starlark.Builtin, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) { + if len(args) < 1 { + return nil, fmt.Errorf("assert.fails: missing argument for fn") + } + + failFn := args[0] + failArgs := args[1:] + + if _, err := starlark.Call(thread, failFn, failArgs, kwargs); err != nil { + if _, ok := err.(*starlark.EvalError); ok { + // an eval error means the function failed and the assertion passes + // return a struct with `message` as the string from the error + s := starlark.NewBuiltin("struct", starlarkstruct.Make) + result := starlarkstruct.FromStringDict(s, starlark.StringDict{ + "message": starlark.String(err.Error()), + }) + return result, nil + } + + return nil, err + } + + // if no error was returned, the assertion fails + err := assertionError{ + msg: fmt.Sprintf("function %s should have failed", failFn.(starlark.Callable).Name()), + callStack: thread.CallStack(), + } + t.Failures = append(t.Failures, err) + return nil, err +} + var tokenToString = map[syntax.Token]string{ syntax.LT: "lesser", syntax.GT: "greater", diff --git a/internal/go/skycfg/assert_test.go b/internal/go/skycfg/assert_test.go index 8168810..8ef76ab 100644 --- a/internal/go/skycfg/assert_test.go +++ b/internal/go/skycfg/assert_test.go @@ -254,6 +254,51 @@ func TestBinaryAsserts(t *testing.T) { } } +func TestAssertFails(t *testing.T) { + testCases := []assertUnaryTestCase{ + assertUnaryTestCase{ + assertTestCaseImpl: assertTestCaseImpl{ + expFailure: false, + expError: false, + }, + val: `fail, "this is an expected failure"`, + }, + assertUnaryTestCase{ + assertTestCaseImpl: assertTestCaseImpl{ + expFailure: true, + expFailureMsg: "assertion failed: function print should have failed", + expError: false, + }, + val: `print, "this should have failed"`, + }, + assertUnaryTestCase{ + assertTestCaseImpl: assertTestCaseImpl{ + expFailure: false, + expError: true, + expErrorMsg: "invalid call of non-function (int)", + }, + val: `3, "this should be an error"`, + }, + assertUnaryTestCase{ + assertTestCaseImpl: assertTestCaseImpl{ + expFailure: false, + expError: true, + expErrorMsg: "assert.fails: missing argument for fn", + }, + val: ``, + }, + } + + for _, testCase := range testCases { + cmd := fmt.Sprintf( + `t.assert.fails(%s)`, + testCase.val, + ) + + evalAndReportResults(t, cmd, testCase) + } +} + func TestMultipleAssertionErrors(t *testing.T) { thread := new(starlark.Thread) assertModule := AssertModule() @@ -300,7 +345,8 @@ func evalAndReportResults(t *testing.T, cmd string, testCase assertTestCase) { }), } env := starlark.StringDict{ - "t": testCtx, + "t": testCtx, + "fail": Fail, } _, err := starlark.Eval( diff --git a/internal/go/skycfg/fail.go b/internal/go/skycfg/fail.go new file mode 100644 index 0000000..4911eff --- /dev/null +++ b/internal/go/skycfg/fail.go @@ -0,0 +1,19 @@ +package skycfg + +import ( + "fmt" + + "go.starlark.net/starlark" +) + +var Fail = starlark.NewBuiltin("fail", failImpl) + +func failImpl(t *starlark.Thread, fn *starlark.Builtin, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) { + var msg string + if err := starlark.UnpackPositionalArgs(fn.Name(), args, kwargs, 1, &msg); err != nil { + return nil, err + } + callStack := t.CallStack() + callStack.Pop() + return nil, fmt.Errorf("[%s] %s\n%s", callStack.At(0).Pos, msg, callStack.String()) +} diff --git a/skycfg.go b/skycfg.go index b80be06..1181dec 100644 --- a/skycfg.go +++ b/skycfg.go @@ -175,7 +175,7 @@ func UnstablePredeclaredModules(r unstableProtoRegistry) starlark.StringDict { func predeclaredModules() (modules starlark.StringDict, proto *impl.ProtoModule) { proto = impl.NewProtoModule(nil /* TODO: registry from options */) modules = starlark.StringDict{ - "fail": starlark.NewBuiltin("fail", skyFail), + "fail": impl.Fail, "hash": impl.HashModule(), "json": impl.JsonModule(), "proto": proto, @@ -426,13 +426,3 @@ func (c *Config) Tests() []*Test { func skyPrint(t *starlark.Thread, msg string) { fmt.Fprintf(os.Stderr, "[%v] %s\n", t.CallFrame(1).Pos, msg) } - -func skyFail(t *starlark.Thread, fn *starlark.Builtin, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) { - var msg string - if err := starlark.UnpackPositionalArgs(fn.Name(), args, kwargs, 1, &msg); err != nil { - return nil, err - } - callStack := t.CallStack() - callStack.Pop() - return nil, fmt.Errorf("[%s] %s\n%s", callStack.At(0).Pos, msg, callStack.String()) -}