Skip to content

Commit

Permalink
use fact
Browse files Browse the repository at this point in the history
  • Loading branch information
sylvia7788 committed Jul 13, 2022
1 parent 30ee69c commit e7f8bd0
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 185 deletions.
148 changes: 83 additions & 65 deletions contextcheck.go
Expand Up @@ -6,7 +6,6 @@ import (
"go/types"
"strconv"
"strings"
"sync"

"github.com/gostaticanalysis/analysisutil"
"golang.org/x/tools/go/analysis"
Expand All @@ -23,6 +22,7 @@ func NewAnalyzer() *analysis.Analyzer {
Requires: []*analysis.Analyzer{
buildssa.Analyzer,
},
FactTypes: []analysis.Fact{(*ctxFact)(nil)},
}
}

Expand All @@ -39,27 +39,38 @@ const (
CtxInOut = CtxIn | CtxOut
)

var (
checkedMap = make(map[string]bool)
checkedMapLock sync.RWMutex
c *collector
)
type resInfo struct {
Valid bool
Funcs []string
}

type ctxFact map[string]resInfo

func (*ctxFact) String() string { return "ctxCheck" }
func (*ctxFact) AFact() {}

type runner struct {
pass *analysis.Pass
ctxTyp *types.Named
ctxPTyp *types.Pointer
cmpPath string
skipFile map[*ast.File]bool

currentFact ctxFact
}

func NewRun(pkgs []*packages.Package) func(pass *analysis.Pass) (interface{}, error) {
c = newCollector(pkgs)
m := make(map[string]bool)
for _, pkg := range pkgs {
m[strings.Split(pkg.PkgPath, "/")[0]] = true
}
return func(pass *analysis.Pass) (interface{}, error) {
defer c.DecUse(pass)
// skip different repo
if !m[strings.Split(pass.Pkg.Path(), "/")[0]] {
return nil, nil
}

r := new(runner)
r.run(pass)
new(runner).run(pass)
return nil, nil
}
}
Expand Down Expand Up @@ -90,21 +101,33 @@ func (r *runner) run(pass *analysis.Pass) {
}

r.skipFile = make(map[*ast.File]bool)
r.currentFact = make(ctxFact)

var tmpFuncs []*ssa.Function
for _, f := range funcs {
// skip checked function
key := f.RelString(nil)
_, ok := getValue(key)
if ok {
if _, ok := r.currentFact[key]; ok {
continue
}

if !r.checkIsEntry(f, f.Pos()) {
// record the result of nomal function
checkingMap := make(map[string]bool)
checkingMap[key] = true
r.setFact(key, r.checkFuncWithoutCtx(f, checkingMap), f.Name())
continue
}

tmpFuncs = append(tmpFuncs, f)
}

for _, f := range tmpFuncs {
r.checkFuncWithCtx(f)
setValue(key, true)
}

if len(r.currentFact) > 0 {
pass.ExportPackageFact(&r.currentFact)
}
}

Expand Down Expand Up @@ -269,16 +292,6 @@ func (r *runner) collectCtxRef(f *ssa.Function) (refMap map[ssa.Instruction]bool
return
}

func (r *runner) buildPkg(f *ssa.Function) (ff *ssa.Function) {
if f.Blocks != nil {
ff = f
return
}

ff = c.GetFunction(f)
return
}

func (r *runner) checkIsSameRepo(s string) bool {
return strings.HasPrefix(s, r.cmpPath+"/")
}
Expand Down Expand Up @@ -313,31 +326,10 @@ func (r *runner) checkFuncWithCtx(f *ssa.Function) {
}

key := ff.RelString(nil)
valid, ok := getValue(key)
res, ok := r.getValue(key, ff)
if ok {
if !valid {
r.pass.Reportf(instr.Pos(), "Function `%s` should pass the context parameter", ff.Name())
}
continue
}

// check is thunk or bound
if strings.HasSuffix(key, "$thunk") || strings.HasSuffix(key, "$bound") {
continue
}

// if ff has no ctx, start deep traversal check
if !r.checkIsEntry(ff, instr.Pos()) {
if ff = r.buildPkg(ff); ff == nil {
continue
}

checkingMap := make(map[string]bool)
checkingMap[key] = true
valid := r.checkFuncWithoutCtx(ff, checkingMap)
setValue(key, valid)
if !valid {
r.pass.Reportf(instr.Pos(), "Function `%s` should pass the context parameter", ff.Name())
if !res.Valid {
r.pass.Reportf(instr.Pos(), "Function `%s` should pass the context parameter", strings.Join(reverse(res.Funcs), "->"))
}
}
}
Expand All @@ -346,6 +338,7 @@ func (r *runner) checkFuncWithCtx(f *ssa.Function) {

func (r *runner) checkFuncWithoutCtx(f *ssa.Function, checkingMap map[string]bool) (ret bool) {
ret = true
orgKey := f.RelString(nil)
for _, b := range f.Blocks {
for _, instr := range b.Instrs {
tp, ok := r.getCtxType(instr)
Expand All @@ -362,7 +355,6 @@ func (r *runner) checkFuncWithoutCtx(f *ssa.Function, checkingMap map[string]boo
if tp&CtxInField == 0 {
ret = false
}
continue
}

ff := r.getFunction(instr)
Expand All @@ -371,11 +363,13 @@ func (r *runner) checkFuncWithoutCtx(f *ssa.Function, checkingMap map[string]boo
}

key := ff.RelString(nil)
valid, ok := getValue(key)
res, ok := r.getValue(key, ff)
if ok {
if !valid {
if !res.Valid {
ret = false
r.pass.Reportf(instr.Pos(), "Function `%s` should pass the context parameter", ff.Name())

// save the call link
r.setFact(orgKey, res.Valid, res.Funcs...)
}
continue
}
Expand All @@ -386,21 +380,21 @@ func (r *runner) checkFuncWithoutCtx(f *ssa.Function, checkingMap map[string]boo
}

if !r.checkIsEntry(ff, instr.Pos()) {
// handler ring call
if checkingMap[key] {
// cannot get info from fact, skip
if ff.Blocks == nil {
continue
}
checkingMap[key] = true

if ff = r.buildPkg(ff); ff == nil {
// handler ring call
if checkingMap[key] {
continue
}
checkingMap[key] = true

valid := r.checkFuncWithoutCtx(ff, checkingMap)
setValue(key, valid)
r.setFact(orgKey, valid, ff.Name())
if !valid {
ret = false
r.pass.Reportf(instr.Pos(), "Function `%s` should pass the context parameter", ff.Name())
}
}
}
Expand Down Expand Up @@ -501,15 +495,39 @@ func (r *runner) isCtxType(tp types.Type) bool {
return types.Identical(tp, r.ctxTyp) || types.Identical(tp, r.ctxPTyp)
}

func getValue(key string) (valid, ok bool) {
checkedMapLock.RLock()
valid, ok = checkedMap[key]
checkedMapLock.RUnlock()
func (r *runner) getValue(key string, f *ssa.Function) (res resInfo, ok bool) {
res, ok = r.currentFact[key]
if ok {
return
}

if f.Pkg == nil {
return
}

var fact ctxFact
if r.pass.ImportPackageFact(f.Pkg.Pkg, &fact) {
res, ok = fact[key]
}
return
}

func setValue(key string, valid bool) {
checkedMapLock.Lock()
checkedMap[key] = valid
checkedMapLock.Unlock()
func (r *runner) setFact(key string, valid bool, funcs ...string) {
r.currentFact[key] = resInfo{
Valid: valid,
Funcs: append(r.currentFact[key].Funcs, funcs...),
}
}

func reverse(arr1 []string) (arr2 []string) {
l := len(arr1)
if l == 0 {
return
}
arr2 = make([]string, l)
for i := 0; i <= l/2; i++ {
arr2[i] = arr1[l-1-i]
arr2[l-1-i] = arr1[i]
}
return
}
5 changes: 4 additions & 1 deletion contextcheck_test.go
Expand Up @@ -6,10 +6,13 @@ import (

"github.com/sylvia7788/contextcheck"
"golang.org/x/tools/go/analysis/analysistest"
"golang.org/x/tools/go/packages"
)

func Test(t *testing.T) {
log.SetFlags(log.Lshortfile)
testdata := analysistest.TestData()
analysistest.Run(t, testdata, contextcheck.NewAnalyzer(), "a")
analyzer := contextcheck.NewAnalyzer()
analyzer.Run = contextcheck.NewRun([]*packages.Package{{PkgPath: "a"}})
analysistest.Run(t, testdata, analyzer, "a")
}
116 changes: 0 additions & 116 deletions dep.go

This file was deleted.

0 comments on commit e7f8bd0

Please sign in to comment.