diff --git a/tenv.go b/tenv.go index accc0b4..bb01ab5 100644 --- a/tenv.go +++ b/tenv.go @@ -28,10 +28,13 @@ var Analyzer = &analysis.Analyzer{ var ( F = "force" fflag bool + A = "all" + aflag bool ) func init() { Analyzer.Flags.BoolVar(&fflag, F, false, "the force option will also run against code prior to Go1.17") + Analyzer.Flags.BoolVar(&aflag, A, false, "the all option will run against all method in test file") } func run(pass *analysis.Pass) (interface{}, error) { @@ -50,7 +53,9 @@ func run(pass *analysis.Pass) (interface{}, error) { case *ast.FuncDecl: checkFunc(pass, decl) case *ast.GenDecl: - checkGenDecl(pass, decl) + if aflag { + checkGenDecl(pass, decl) + } } } } @@ -61,19 +66,21 @@ func run(pass *analysis.Pass) (interface{}, error) { } func checkFunc(pass *analysis.Pass, n *ast.FuncDecl) { - for _, stmt := range n.Body.List { - switch stmt := stmt.(type) { - case *ast.ExprStmt: - if !checkExprStmt(pass, stmt, n) { - continue - } - case *ast.IfStmt: - if !checkIfStmt(pass, stmt, n) { - continue - } - case *ast.AssignStmt: - if !checkAssignStmt(pass, stmt, n) { - continue + if targetRunner(n) { + for _, stmt := range n.Body.List { + switch stmt := stmt.(type) { + case *ast.ExprStmt: + if !checkExprStmt(pass, stmt, n) { + continue + } + case *ast.IfStmt: + if !checkIfStmt(pass, stmt, n) { + continue + } + case *ast.AssignStmt: + if !checkAssignStmt(pass, stmt, n) { + continue + } } } } @@ -214,3 +221,50 @@ func checkVersion() bool { func isForceExec() bool { return fflag } + +func targetRunner(funcDecl *ast.FuncDecl) bool { + if aflag { + return true + } + params := funcDecl.Type.Params.List + for _, p := range params { + switch typ := p.Type.(type) { + case *ast.StarExpr: + if checkStarExprTarget(typ) { + return true + } + case *ast.SelectorExpr: + if checkSelectorExprTarget(typ) { + return true + } + } + } + return false +} + +func checkStarExprTarget(typ *ast.StarExpr) bool { + selector, ok := typ.X.(*ast.SelectorExpr) + if !ok { + return false + } + x, ok := selector.X.(*ast.Ident) + if !ok { + return false + } + targetName := x.Name + "." + selector.Sel.Name + switch targetName { + case "testing.T", "testing.B": + return true + default: + return false + } +} + +func checkSelectorExprTarget(typ *ast.SelectorExpr) bool { + x, ok := typ.X.(*ast.Ident) + if !ok { + return false + } + targetName := x.Name + "." + typ.Sel.Name + return targetName == "testing.TB" +} diff --git a/testdata/src/a/a_test.go b/testdata/src/a/a_test.go index 21fccd7..5fe50db 100644 --- a/testdata/src/a/a_test.go +++ b/testdata/src/a/a_test.go @@ -6,19 +6,19 @@ import ( ) var ( - e = os.Setenv("a", "b") // want "variable e is not using t.Setenv" + e = os.Setenv("a", "b") // if -all = true, want "variable e is not using t.Setenv" _ = e env string ) func setup() { - os.Setenv("a", "b") // want "func setup is not using t.Setenv" - err := os.Setenv("a", "b") // want "func setup is not using t.Setenv" + os.Setenv("a", "b") // if -all = true, want "func setup is not using t.Setenv" + err := os.Setenv("a", "b") // if -all = true, want "func setup is not using t.Setenv" if err != nil { _ = err } env = os.Getenv("a") - os.Setenv("a", "b") // want "func setup is not using t.Setenv" + os.Setenv("a", "b") // if -all = true, "func setup is not using t.Setenv" } func TestF(t *testing.T) { @@ -28,3 +28,15 @@ func TestF(t *testing.T) { _ = err } } + +func BenchmarkF(b *testing.B) { + testTB(b) + os.Setenv("a", "b") // want "func BenchmarkF is not using t.Setenv" + if err := os.Setenv("a", "b"); err != nil { // want "func BenchmarkF is not using t.Setenv" + _ = err + } +} + +func testTB(tb testing.TB) { + os.Setenv("a", "b") // want "func testTB is not using t.Setenv" +}