Skip to content

Commit

Permalink
Add assert.fails to check if a function fails. (#66)
Browse files Browse the repository at this point in the history
Example:

```python
def might_fail(input):
  if input < 10:
    fail("input is less than 10")
  else:
    return input

def test_might_fail(t):
  t.assert.fails(might_fail, 3)
  # passes

  t.assert.fails(might_fail, 14)
  # assertion failure: function might_fail should have failed
```
  • Loading branch information
mmoriarity-stripe authored and jmillikin-stripe committed Dec 6, 2019
1 parent 40659e5 commit dfe37ba
Show file tree
Hide file tree
Showing 5 changed files with 109 additions and 12 deletions.
1 change: 1 addition & 0 deletions CONTRIBUTORS
Expand Up @@ -10,5 +10,6 @@ Benjamin Yolken <yolken@stripe.com>
Dmitry Ilyevsky <ilyevsky@gmail.com> # GM Cruise LLC
Isaac Diamond <idiamond@stripe.com>
John Millikin <jmillikin@stripe.com>
Matt Moriarity <mmoriarity@stripe.com>

# Please alphabetize new entries.
41 changes: 41 additions & 0 deletions internal/go/skycfg/assert.go
Expand Up @@ -21,6 +21,7 @@ import (
"sort"

"go.starlark.net/starlark"
"go.starlark.net/starlarkstruct"
"go.starlark.net/syntax"
)

Expand All @@ -39,6 +40,9 @@ func AssertModule() *TestContext {
for op, str := range tokenToString {
ctx.Attrs[str] = starlark.NewBuiltin(fmt.Sprintf("assert.%s", str), ctx.AssertBinaryImpl(op))
}

ctx.Attrs["fails"] = starlark.NewBuiltin("assert.fails", ctx.AssertFails)

return ctx
}

Expand All @@ -47,6 +51,7 @@ type assertionError struct {
op *syntax.Token
val1 starlark.Value
val2 starlark.Value
msg string
callStack starlark.CallStack
}

Expand All @@ -55,6 +60,11 @@ func (err assertionError) Error() string {
position := callStack.At(0).Pos.String()
backtrace := callStack.String()

// use custom message if provided
if err.msg != "" {
return fmt.Sprintf("[%s] assertion failed: %s\n%s", position, err.msg, backtrace)
}

// straight boolean assertions like assert.true(false)
if err.op == nil {
return fmt.Sprintf("[%s] assertion failed\n%s", position, backtrace)
Expand Down Expand Up @@ -153,6 +163,37 @@ func (t *TestContext) AssertBinaryImpl(op syntax.Token) func(thread *starlark.Th
}
}

func (t *TestContext) AssertFails(thread *starlark.Thread, fn *starlark.Builtin, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) {
if len(args) < 1 {
return nil, fmt.Errorf("assert.fails: missing argument for fn")
}

failFn := args[0]
failArgs := args[1:]

if _, err := starlark.Call(thread, failFn, failArgs, kwargs); err != nil {
if _, ok := err.(*starlark.EvalError); ok {
// an eval error means the function failed and the assertion passes
// return a struct with `message` as the string from the error
s := starlark.NewBuiltin("struct", starlarkstruct.Make)
result := starlarkstruct.FromStringDict(s, starlark.StringDict{
"message": starlark.String(err.Error()),
})
return result, nil
}

return nil, err
}

// if no error was returned, the assertion fails
err := assertionError{
msg: fmt.Sprintf("function %s should have failed", failFn.(starlark.Callable).Name()),
callStack: thread.CallStack(),
}
t.Failures = append(t.Failures, err)
return nil, err
}

var tokenToString = map[syntax.Token]string{
syntax.LT: "lesser",
syntax.GT: "greater",
Expand Down
48 changes: 47 additions & 1 deletion internal/go/skycfg/assert_test.go
Expand Up @@ -254,6 +254,51 @@ func TestBinaryAsserts(t *testing.T) {
}
}

func TestAssertFails(t *testing.T) {
testCases := []assertUnaryTestCase{
assertUnaryTestCase{
assertTestCaseImpl: assertTestCaseImpl{
expFailure: false,
expError: false,
},
val: `fail, "this is an expected failure"`,
},
assertUnaryTestCase{
assertTestCaseImpl: assertTestCaseImpl{
expFailure: true,
expFailureMsg: "assertion failed: function print should have failed",
expError: false,
},
val: `print, "this should have failed"`,
},
assertUnaryTestCase{
assertTestCaseImpl: assertTestCaseImpl{
expFailure: false,
expError: true,
expErrorMsg: "invalid call of non-function (int)",
},
val: `3, "this should be an error"`,
},
assertUnaryTestCase{
assertTestCaseImpl: assertTestCaseImpl{
expFailure: false,
expError: true,
expErrorMsg: "assert.fails: missing argument for fn",
},
val: ``,
},
}

for _, testCase := range testCases {
cmd := fmt.Sprintf(
`t.assert.fails(%s)`,
testCase.val,
)

evalAndReportResults(t, cmd, testCase)
}
}

func TestMultipleAssertionErrors(t *testing.T) {
thread := new(starlark.Thread)
assertModule := AssertModule()
Expand Down Expand Up @@ -300,7 +345,8 @@ func evalAndReportResults(t *testing.T, cmd string, testCase assertTestCase) {
}),
}
env := starlark.StringDict{
"t": testCtx,
"t": testCtx,
"fail": Fail,
}

_, err := starlark.Eval(
Expand Down
19 changes: 19 additions & 0 deletions internal/go/skycfg/fail.go
@@ -0,0 +1,19 @@
package skycfg

import (
"fmt"

"go.starlark.net/starlark"
)

var Fail = starlark.NewBuiltin("fail", failImpl)

func failImpl(t *starlark.Thread, fn *starlark.Builtin, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) {
var msg string
if err := starlark.UnpackPositionalArgs(fn.Name(), args, kwargs, 1, &msg); err != nil {
return nil, err
}
callStack := t.CallStack()
callStack.Pop()
return nil, fmt.Errorf("[%s] %s\n%s", callStack.At(0).Pos, msg, callStack.String())
}
12 changes: 1 addition & 11 deletions skycfg.go
Expand Up @@ -175,7 +175,7 @@ func UnstablePredeclaredModules(r unstableProtoRegistry) starlark.StringDict {
func predeclaredModules() (modules starlark.StringDict, proto *impl.ProtoModule) {
proto = impl.NewProtoModule(nil /* TODO: registry from options */)
modules = starlark.StringDict{
"fail": starlark.NewBuiltin("fail", skyFail),
"fail": impl.Fail,
"hash": impl.HashModule(),
"json": impl.JsonModule(),
"proto": proto,
Expand Down Expand Up @@ -426,13 +426,3 @@ func (c *Config) Tests() []*Test {
func skyPrint(t *starlark.Thread, msg string) {
fmt.Fprintf(os.Stderr, "[%v] %s\n", t.CallFrame(1).Pos, msg)
}

func skyFail(t *starlark.Thread, fn *starlark.Builtin, args starlark.Tuple, kwargs []starlark.Tuple) (starlark.Value, error) {
var msg string
if err := starlark.UnpackPositionalArgs(fn.Name(), args, kwargs, 1, &msg); err != nil {
return nil, err
}
callStack := t.CallStack()
callStack.Pop()
return nil, fmt.Errorf("[%s] %s\n%s", callStack.At(0).Pos, msg, callStack.String())
}

0 comments on commit dfe37ba

Please sign in to comment.