From e6ec6681d19fd4c6081a34a666d2075bd1aa9d61 Mon Sep 17 00:00:00 2001 From: Iskander Sharipov Date: Sun, 7 Nov 2021 02:06:07 +0300 Subject: [PATCH] ruleguard: add support for local functions This feature is useful for rule filters readability improvements. Instead of copying a complex `Where()` expression several times, one can now use a local function literal to define that filter operation and use it inside `Where()` expressions. Here is an example: ```go func preferFprint(m dsl.Matcher) { isFmtPackage := func(v dsl.Var) bool { return v.Text == "fmt" && v.Object.Is(`PkgName`) } m.Match(`$w.Write([]byte($fmt.Sprint($*args)))`). Where(m["w"].Type.Implements("io.Writer") && isFmtPackage(m["fmt"])). Suggest("fmt.Fprint($w, $args)"). Report(`fmt.Fprint($w, $args) should be preferred to the $$`) m.Match(`$w.Write([]byte($fmt.Sprintf($*args)))`). Where(m["w"].Type.Implements("io.Writer") && isFmtPackage(m["fmt"])). Suggest("fmt.Fprintf($w, $args)"). Report(`fmt.Fprintf($w, $args) should be preferred to the $$`) // ...etc } ``` Note that we used `isFmtPackage` in more than 1 rule. Functions can accept almost arbitrary params, but there are some restrictions on what kinds of arguments they can receive right now. These arguments work: * Matcher var expressions like `m["varname"]` * Basic literals like `"foo"`, `104`, `5.2` * Constants --- Makefile | 2 +- analyzer/analyzer_test.go | 1 + analyzer/testdata/src/gocritic/rules.go | 13 +- analyzer/testdata/src/localfunc/rules.go | 66 +++++++++ analyzer/testdata/src/localfunc/target.go | 67 +++++++++ analyzer/testdata/src/localfunc/target2.go | 7 + go.mod | 1 + go.sum | 3 + ruleguard/debug_test.go | 14 ++ ruleguard/irconv/irconv.go | 160 ++++++++++++++++++++- ruleguard/ruleguard_error_test.go | 32 ++++- 11 files changed, 352 insertions(+), 14 deletions(-) create mode 100644 analyzer/testdata/src/localfunc/rules.go create mode 100644 analyzer/testdata/src/localfunc/target.go create mode 100644 analyzer/testdata/src/localfunc/target2.go diff --git a/Makefile b/Makefile index b0a0c2aa..342e5059 100644 --- a/Makefile +++ b/Makefile @@ -27,7 +27,7 @@ test-release: @echo "everything is OK" lint: - curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh | sh -s -- -b $(GOPATH_DIR)/bin v1.30.0 + curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh | sh -s -- -b $(GOPATH_DIR)/bin v1.43.0 $(GOPATH_DIR)/bin/golangci-lint run ./... go build -o go-ruleguard ./cmd/ruleguard ./go-ruleguard -debug-imports -rules rules.go ./... diff --git a/analyzer/analyzer_test.go b/analyzer/analyzer_test.go index 86c8b244..ce3051f1 100644 --- a/analyzer/analyzer_test.go +++ b/analyzer/analyzer_test.go @@ -42,6 +42,7 @@ var tests = []struct { {name: "comments"}, {name: "stdlib"}, {name: "uber"}, + {name: "localfunc"}, {name: "goversion", flags: map[string]string{"go": "1.16"}}, } diff --git a/analyzer/testdata/src/gocritic/rules.go b/analyzer/testdata/src/gocritic/rules.go index b88ddd6a..12919f2c 100644 --- a/analyzer/testdata/src/gocritic/rules.go +++ b/analyzer/testdata/src/gocritic/rules.go @@ -170,21 +170,22 @@ func appendAssign(m dsl.Matcher) { //doc:before w.Write([]byte(fmt.Sprintf("%x", 10))) //doc:after fmt.Fprintf(w, "%x", 10) func preferFprint(m dsl.Matcher) { + isFmtPackage := func(v dsl.Var) bool { + return v.Text == "fmt" && v.Object.Is(`PkgName`) + } + m.Match(`$w.Write([]byte($fmt.Sprint($*args)))`). - Where(m["w"].Type.Implements("io.Writer") && - m["fmt"].Text == "fmt" && m["fmt"].Object.Is(`PkgName`)). + Where(m["w"].Type.Implements("io.Writer") && isFmtPackage(m["fmt"])). Suggest("fmt.Fprint($w, $args)"). Report(`fmt.Fprint($w, $args) should be preferred to the $$`) m.Match(`$w.Write([]byte($fmt.Sprintf($*args)))`). - Where(m["w"].Type.Implements("io.Writer") && - m["fmt"].Text == "fmt" && m["fmt"].Object.Is(`PkgName`)). + Where(m["w"].Type.Implements("io.Writer") && isFmtPackage(m["fmt"])). Suggest("fmt.Fprintf($w, $args)"). Report(`fmt.Fprintf($w, $args) should be preferred to the $$`) m.Match(`$w.Write([]byte($fmt.Sprintln($*args)))`). - Where(m["w"].Type.Implements("io.Writer") && - m["fmt"].Text == "fmt" && m["fmt"].Object.Is(`PkgName`)). + Where(m["w"].Type.Implements("io.Writer") && isFmtPackage(m["fmt"])). Suggest("fmt.Fprintln($w, $args)"). Report(`fmt.Fprintln($w, $args) should be preferred to the $$`) } diff --git a/analyzer/testdata/src/localfunc/rules.go b/analyzer/testdata/src/localfunc/rules.go new file mode 100644 index 00000000..091731e8 --- /dev/null +++ b/analyzer/testdata/src/localfunc/rules.go @@ -0,0 +1,66 @@ +// +build ignore + +package gorules + +import "github.com/quasilyte/go-ruleguard/dsl" + +func testRules(m dsl.Matcher) { + bothConst := func(x, y dsl.Var) bool { + return x.Const && y.Const + } + m.Match(`test("both const", $x, $y)`). + Where(bothConst(m["x"], m["y"])). + Report(`true`) + + intValue := func(x dsl.Var, val int) bool { + return x.Value.Int() == val + } + m.Match(`test("== 10", $x)`). + Where(intValue(m["x"], 10)). + Report(`true`) + + isZero := func(x dsl.Var) bool { return x.Value.Int() == 0 } + m.Match(`test("== 0", $x)`). + Where(isZero(m["x"])). + Report(`true`) + + // Testing closure-captured m variable. + fmtIsImported := func() bool { + return m.File().Imports(`fmt`) + } + m.Match(`test("fmt is imported")`). + Where(fmtIsImported()). + Report(`true`) + + // Testing explicitly passed matcher. + ioutilIsImported := func(m2 dsl.Matcher) bool { + return m2.File().Imports(`io/ioutil`) + } + m.Match(`test("ioutil is imported")`). + Where(ioutilIsImported(m)). + Report(`true`) + + // Test precedence after the macro expansion. + isSimpleExpr := func(v dsl.Var) bool { + return v.Const || v.Node.Is(`Ident`) + } + m.Match(`test("check precedence", $x, $y)`). + Where(isSimpleExpr(m["x"]) && m["y"].Text == "err"). + Report(`true`) + + // Test variables referenced through captured m. + isStringLit := func() bool { + return m["x"].Node.Is(`BasicLit`) && m["x"].Type.Is(`string`) + } + m.Match(`test("is string", $x)`). + Where(isStringLit()). + Report(`true`) + + // Test predicate applied to different matcher vars. + isPureCall := func(v dsl.Var) bool { + return v.Node.Is(`CallExpr`) && v.Pure + } + m.Match(`test("is pure call", $x, $y)`). + Where(isPureCall(m["x"]) && isPureCall(m["y"])). + Report(`true`) +} diff --git a/analyzer/testdata/src/localfunc/target.go b/analyzer/testdata/src/localfunc/target.go new file mode 100644 index 00000000..b8c9fc6c --- /dev/null +++ b/analyzer/testdata/src/localfunc/target.go @@ -0,0 +1,67 @@ +package localfunc + +import ( + "fmt" + "io/ioutil" +) + +func test(args ...interface{}) {} + +func f() interface{} { return nil } + +func _() { + fmt.Println("ok") + _ = ioutil.Discard + + var i int + var err error + + test("both const", 1, 2) // want `true` + test("both const", 1, 2+2) // want `true` + test("both const", i, 2) + test("both const", 1, i) + test("both const", i, i) + + test("== 10", 10) // want `true` + test("== 10", 9+1) // want `true` + test("== 10", 11) + test("== 10", i) + + test("== 0", 0) // want `true` + test("== 0", 1-1) // want `true` + test("== 0", 11) + test("== 0", i) + + test("fmt is imported") // want `true` + + test("ioutil is imported") // want `true` + + test("check precedence", 1, err) // want `true` + test("check precedence", 1+2, err) // want `true` + test("check precedence", i, err) // want `true` + test("check precedence", err, err) // want `true` + test("check precedence", f(), err) + test("check precedence", 1) + test("check precedence", 1+2) + test("check precedence", i) + test("check precedence", err) + test("check precedence", f()) + test("check precedence", 1, nil) + test("check precedence", 1+2, nil) + test("check precedence", i, nil) + test("check precedence", err, nil) + test("check precedence", f(), nil) + + test("is string", "yes") // want `true` + test("is string", `yes`) // want `true` + test("is string", 1) + test("is string", i) + + test("is pure call", int(0), int(1)) // want `true` + test("is pure call", string("f"), int(1)) // want `true` + test("is pure call", f(), f()) + test("is pure call", int(0), 1) + test("is pure call", 0, int(1)) + test("is pure call", f(), int(1)) + test("is pure call", 1, 1) +} diff --git a/analyzer/testdata/src/localfunc/target2.go b/analyzer/testdata/src/localfunc/target2.go new file mode 100644 index 00000000..1c641a24 --- /dev/null +++ b/analyzer/testdata/src/localfunc/target2.go @@ -0,0 +1,7 @@ +package localfunc + +func _() { + test("fmt is imported") + + test("ioutil is imported") +} diff --git a/go.mod b/go.mod index bea2bb14..498f3590 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/quasilyte/go-ruleguard go 1.15 require ( + github.com/go-toolsmith/astcopy v1.0.0 github.com/go-toolsmith/astequal v1.0.1 github.com/google/go-cmp v0.5.2 github.com/quasilyte/go-ruleguard/dsl v0.3.10 diff --git a/go.sum b/go.sum index 1b5ec4e9..2d9e8c2b 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,6 @@ +github.com/go-toolsmith/astcopy v1.0.0 h1:OMgl1b1MEpjFQ1m5ztEO06rz5CUd3oBv9RF7+DyvdG8= +github.com/go-toolsmith/astcopy v1.0.0/go.mod h1:vrgyG+5Bxrnz4MZWPF+pI4R8h3qKRjjyvV/DSez4WVQ= +github.com/go-toolsmith/astequal v1.0.0/go.mod h1:H+xSiq0+LtiDC11+h1G32h7Of5O3CYFJ99GVbS5lDKY= github.com/go-toolsmith/astequal v1.0.1 h1:JbSszi42Jiqu36Gnf363HWS9MTEAz67vTQLponh3Moc= github.com/go-toolsmith/astequal v1.0.1/go.mod h1:4oGA3EZXTVItV/ipGiOx7NWkY5veFfcsOJVS2YxltLw= github.com/go-toolsmith/strparse v1.0.0 h1:Vcw78DnpCAKlM20kSbAyO4mPfJn/lyYA4BJUDxe2Jb4= diff --git a/ruleguard/debug_test.go b/ruleguard/debug_test.go index 65cba16c..32e2de24 100644 --- a/ruleguard/debug_test.go +++ b/ruleguard/debug_test.go @@ -158,6 +158,20 @@ func TestDebug(t *testing.T) { ` $x []string: []string{"x"}`, }, }, + + `isConst := func(v dsl.Var) bool { return v.Const }; m.Match("_ = $x").Where(isConst(m["x"]) && !m["x"].Type.Is("string"))`: { + `_ = 10`: nil, + + `_ = "str"`: { + `input.go:4: [rules.go:5] rejected by !m["x"].Type.Is("string")`, + ` $x string: "str"`, + }, + + `_ = f()`: { + `input.go:4: [rules.go:5] rejected by isConst(m["x"])`, + ` $x interface{}: f()`, + }, + }, } loadRulesFromExpr := func(e *Engine, s string) { diff --git a/ruleguard/irconv/irconv.go b/ruleguard/irconv/irconv.go index ceb6e816..386e2d80 100644 --- a/ruleguard/irconv/irconv.go +++ b/ruleguard/irconv/irconv.go @@ -10,6 +10,7 @@ import ( "strconv" "strings" + "github.com/go-toolsmith/astcopy" "github.com/quasilyte/go-ruleguard/ruleguard/goutil" "github.com/quasilyte/go-ruleguard/ruleguard/ir" "golang.org/x/tools/go/ast/astutil" @@ -52,13 +53,20 @@ type convError struct { err error } +type localMacroFunc struct { + name string + params []string + template ast.Expr +} + type converter struct { types *types.Info pkg *types.Package fset *token.FileSet src []byte - group *ir.RuleGroup + group *ir.RuleGroup + groupFuncs []localMacroFunc dslPkgname string // The local name of the "ruleguard/dsl" package (usually its just "dsl") } @@ -171,6 +179,7 @@ func (conv *converter) convertRuleGroup(decl *ast.FuncDecl) *ir.RuleGroup { Line: conv.fset.Position(decl.Name.Pos()).Line, } conv.group = result + conv.groupFuncs = conv.groupFuncs[:0] result.Name = decl.Name.String() result.MatcherName = decl.Type.Params.List[0].Names[0].String() @@ -181,6 +190,11 @@ func (conv *converter) convertRuleGroup(decl *ast.FuncDecl) *ir.RuleGroup { seenRules := false for _, stmt := range decl.Body.List { + if assign, ok := stmt.(*ast.AssignStmt); ok && assign.Tok == token.DEFINE { + conv.localDefine(assign) + continue + } + if _, ok := stmt.(*ast.DeclStmt); ok { continue } @@ -208,6 +222,146 @@ func (conv *converter) convertRuleGroup(decl *ast.FuncDecl) *ir.RuleGroup { return result } +func (conv *converter) findLocalMacro(call *ast.CallExpr) *localMacroFunc { + fn, ok := call.Fun.(*ast.Ident) + if !ok { + return nil + } + for i := range conv.groupFuncs { + if conv.groupFuncs[i].name == fn.Name { + return &conv.groupFuncs[i] + } + } + return nil +} + +func (conv *converter) expandMacro(macro *localMacroFunc, call *ast.CallExpr) ir.FilterExpr { + // Check that call args are OK. + // Since "function calls" are implemented as a macro expansion here, + // we don't allow arguments that have a non-trivial evaluation. + isSafe := func(arg ast.Expr) bool { + switch arg := astutil.Unparen(arg).(type) { + case *ast.BasicLit, *ast.Ident: + return true + + case *ast.IndexExpr: + mapIdent, ok := astutil.Unparen(arg.X).(*ast.Ident) + if !ok { + return false + } + if mapIdent.Name != conv.group.MatcherName { + return false + } + key, ok := astutil.Unparen(arg.Index).(*ast.BasicLit) + if !ok || key.Kind != token.STRING { + return false + } + return true + + default: + return false + } + } + args := map[string]ast.Expr{} + for i, arg := range call.Args { + paramName := macro.params[i] + if !isSafe(arg) { + panic(conv.errorf(arg, "unsupported/too complex %s argument", paramName)) + } + args[paramName] = astutil.Unparen(arg) + } + + body := astcopy.Expr(macro.template) + expanded := astutil.Apply(body, nil, func(cur *astutil.Cursor) bool { + if ident, ok := cur.Node().(*ast.Ident); ok { + arg, ok := args[ident.Name] + if ok { + cur.Replace(arg) + return true + } + } + // astcopy above will copy the AST tree, but it won't update + // the associated types.Info map of const values. + // We'll try to solve that issue at least partially here. + if lit, ok := cur.Node().(*ast.BasicLit); ok { + switch lit.Kind { + case token.STRING: + val, err := strconv.Unquote(lit.Value) + if err == nil { + conv.types.Types[lit] = types.TypeAndValue{ + Type: types.Typ[types.UntypedString], + Value: constant.MakeString(val), + } + } + case token.INT: + val, err := strconv.ParseInt(lit.Value, 0, 64) + if err == nil { + conv.types.Types[lit] = types.TypeAndValue{ + Type: types.Typ[types.UntypedInt], + Value: constant.MakeInt64(val), + } + } + case token.FLOAT: + val, err := strconv.ParseFloat(lit.Value, 64) + if err == nil { + conv.types.Types[lit] = types.TypeAndValue{ + Type: types.Typ[types.UntypedFloat], + Value: constant.MakeFloat64(val), + } + } + } + } + return true + }) + + return conv.convertFilterExpr(expanded.(ast.Expr)) +} + +func (conv *converter) localDefine(assign *ast.AssignStmt) { + if len(assign.Lhs) != 1 || len(assign.Rhs) != 1 { + panic(conv.errorf(assign, "multi-value := is not supported")) + } + lhs, ok := assign.Lhs[0].(*ast.Ident) + if !ok { + panic(conv.errorf(assign.Lhs[0], "only simple ident lhs is supported")) + } + rhs := assign.Rhs[0] + fn, ok := rhs.(*ast.FuncLit) + if !ok { + panic(conv.errorf(rhs, "only func literals are supported on the rhs")) + } + typ := conv.types.TypeOf(fn).(*types.Signature) + isBoolResult := typ.Results() != nil && + typ.Results().Len() == 1 && + typ.Results().At(0).Type() == types.Typ[types.Bool] + if !isBoolResult { + var loc ast.Node = fn.Type + if fn.Type.Results != nil { + loc = fn.Type.Results + } + panic(conv.errorf(loc, "only funcs returning bool are supported")) + } + if len(fn.Body.List) != 1 { + panic(conv.errorf(fn.Body, "only simple 1 return statement funcs are supported")) + } + stmt, ok := fn.Body.List[0].(*ast.ReturnStmt) + if !ok { + panic(conv.errorf(fn.Body.List[0], "expected a return statement, found %T", fn.Body.List[0])) + } + var params []string + for _, field := range fn.Type.Params.List { + for _, id := range field.Names { + params = append(params, id.Name) + } + } + macro := localMacroFunc{ + name: lhs.Name, + params: params, + template: stmt.Results[0], + } + conv.groupFuncs = append(conv.groupFuncs, macro) +} + func (conv *converter) doMatcherImport(call *ast.CallExpr) { pkgPath := conv.parseStringArg(call.Args[0]) pkgName := path.Base(pkgPath) @@ -518,6 +672,10 @@ func (conv *converter) convertFilterExprImpl(e ast.Expr) ir.FilterExpr { return ir.FilterExpr{Op: ir.FilterVarFilterOp, Value: op.varName, Args: args} } + if macro := conv.findLocalMacro(e); macro != nil { + return conv.expandMacro(macro, e) + } + args := convertExprList(e.Args) switch op.path { case "Value.Int": diff --git a/ruleguard/ruleguard_error_test.go b/ruleguard/ruleguard_error_test.go index df1f2397..bc2c556f 100644 --- a/ruleguard/ruleguard_error_test.go +++ b/ruleguard/ruleguard_error_test.go @@ -255,7 +255,32 @@ func TestParseRuleError(t *testing.T) { { `m.Match("func[]").Report("$$")`, - `\Qparse match pattern: cannot parse expr: 1:5: expected '(', found '['`, + `(?:expected '\(', found '\['|empty type parameter list)`, + }, + + { + `x := 10; println(x)`, + `\Qonly func literals are supported on the rhs`, + }, + + { + `x, y := 10, 20; println(x, y)`, + `\Qmulti-value := is not supported`, + }, + + { + `f := func() int { return 10 }; f()`, + `\Qonly funcs returning bool are supported`, + }, + + { + `f := func() bool { v := true; return v }; f()`, + `\Qonly simple 1 return statement funcs are supported`, + }, + + { + `f := func(x int) bool { return x == 0 }; m.Match("($x)").Where(f(1+1)).Report("")`, + `\Qunsupported/too complex x argument`, }, } @@ -295,11 +320,6 @@ func TestParseFilterError(t *testing.T) { `unsupported expr: true`, }, - { - `m["x"].Text == 5`, - `cannot convert 5 (untyped int constant) to string`, - }, - { `m["x"].Text.Matches("(12")`, `error parsing regexp: missing closing )`,