Skip to content

Commit

Permalink
feat: support check r.Context()
Browse files Browse the repository at this point in the history
  • Loading branch information
sylvia7788 committed Aug 15, 2022
1 parent a63b6f4 commit ca63215
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 116 deletions.
173 changes: 63 additions & 110 deletions contextcheck.go
Expand Up @@ -38,17 +38,14 @@ const (
CtxIn int = 1 << iota // ctx in function's param
CtxOut // ctx in function's results
CtxInField // ctx in function's field param
HttpRes // http.ResponseWriter in function's param
HttpReq // *http.Request in function's param

HttpHandler = HttpRes | HttpReq
)

const (
EntryWithCtx int = 1 << iota // has ctx in
EntryWithHttpHandler // is http handler
type entryType int

Entry = EntryWithCtx | EntryWithHttpHandler
const (
EntryNone entryType = iota
EntryWithCtx // has ctx in
EntryWithHttpHandler // is http handler
)

type resInfo struct {
Expand Down Expand Up @@ -109,7 +106,7 @@ func (r *runner) run(pass *analysis.Pass) {

type entryInfo struct {
f *ssa.Function // entryfunc
tp int // entrytype
tp entryType // entrytype
}
var tmpFuncs []entryInfo
for _, f := range funcs {
Expand All @@ -119,7 +116,7 @@ func (r *runner) run(pass *analysis.Pass) {
continue
}

if entryType := r.checkIsEntry(f); entryType&Entry == 0 {
if entryType := r.checkIsEntry(f); entryType == EntryNone {
// record the result of nomal function
checkingMap := make(map[string]bool)
checkingMap[key] = true
Expand Down Expand Up @@ -161,7 +158,7 @@ func (r *runner) getRequiedType(pssa *buildssa.SSA, path, name string) (obj *typ
func (r *runner) collectHttpTyps(pssa *buildssa.SSA) {
objRes, pobjRes, ok := r.getRequiedType(pssa, httpPkg, httpRes)
if ok {
r.httpResTyps = append(r.httpResTyps, objRes, pobjRes, types.NewPointer(pobjRes))
r.httpResTyps = append(r.httpResTyps, objRes, pobjRes)
}

objReq, pobjReq, ok := r.getRequiedType(pssa, httpPkg, httpReq)
Expand Down Expand Up @@ -201,27 +198,26 @@ func (r *runner) noImportedContextAndHttp(f *ssa.Function) (ret bool) {
return true
}

func (r *runner) checkIsEntry(f *ssa.Function) (entryType int) {
func (r *runner) checkIsEntry(f *ssa.Function) entryType {
if r.noImportedContextAndHttp(f) {
return
return EntryNone
}

ctxIn, ctxOut := r.checkIsCtx(f)
if ctxOut {
// skip the function which generate ctx
return
return EntryNone
} else if ctxIn {
// has ctx in, ignore *http.Request.Context()
entryType |= EntryWithCtx
return
return EntryWithCtx
}

// check is `func handler(w http.ResponseWriter, r *http.Request) {}`
if r.checkIsHttpHandler(f) {
entryType |= EntryWithHttpHandler
return EntryWithHttpHandler
}

return
return EntryNone
}

func (r *runner) checkIsCtx(f *ssa.Function) (in, out bool) {
Expand Down Expand Up @@ -259,39 +255,12 @@ func (r *runner) checkIsHttpHandler(f *ssa.Function) bool {
return false
}

// must has http.ResponseWriter and *http.Request in param or freevar
var tp int

// check params
// must be `func f(w http.ResponseWriter, r *http.Request) {}`
tuple := f.Signature.Params()
for i := 0; i < tuple.Len(); i++ {
if r.isCtxType(tuple.At(i).Type()) {
return false
} else if r.isHttpReqType(tuple.At(i).Type()) {
tp |= HttpReq
} else if r.isHttpResType(tuple.At(i).Type()) {
tp |= HttpRes
}
if tp == HttpHandler {
return true
}
}

// check freevars
for _, param := range f.FreeVars {
if r.isCtxType(param.Type()) {
return false
} else if r.isHttpReqType(param.Type()) {
tp |= HttpReq
} else if r.isHttpResType(param.Type()) {
tp |= HttpRes
}
if tp == HttpHandler {
return true
}
if tuple.Len() != 2 {
return false
}

return false
return r.isHttpResType(tuple.At(0).Type()) && r.isHttpReqType(tuple.At(1).Type())
}

func (r *runner) collectCtxRef(f *ssa.Function, isHttpHandler bool) (refMap map[ssa.Instruction]bool, ok bool) {
Expand Down Expand Up @@ -358,15 +327,21 @@ func (r *runner) collectCtxRef(f *ssa.Function, isHttpHandler bool) (refMap map[
}
}

for _, param := range f.Params {
if r.isCtxType(param.Type()) {
checkRefs(param, false)
if isHttpHandler {
for _, v := range r.getHttpReqCtx(f) {
checkRefs(v, false)
}
} else {
for _, param := range f.Params {
if r.isCtxType(param.Type()) {
checkRefs(param, false)
}
}
}

for _, param := range f.FreeVars {
if r.isCtxType(param.Type()) {
checkRefs(param, false)
for _, param := range f.FreeVars {
if r.isCtxType(param.Type()) {
checkRefs(param, false)
}
}
}

Expand All @@ -386,14 +361,6 @@ func (r *runner) collectCtxRef(f *ssa.Function, isHttpHandler bool) (refMap map[
}
}

if !isHttpHandler {
return
}

for _, v := range r.getHttpReqCtx(f) {
checkRefs(v, false)
}

return
}

Expand Down Expand Up @@ -421,40 +388,34 @@ func (r *runner) getHttpReqCtx(f *ssa.Function) (rets []ssa.Value) {
checkInstr = func(instr ssa.Instruction, fromAddr bool) {
switch i := instr.(type) {
case ssa.CallInstruction:
// r.Context() only has one recv
if len(i.Common().Args) != 1 {
break
}

// find r.Context()
if r.getCallInstrCtxType(i)&CtxOut != CtxOut {
break
}

for _, v := range i.Common().Args {
if !r.isHttpReqType(v.Type()) {
continue
}

f := r.getFunction(instr)
if f == nil {
continue
}

// check is r.Context
if f.Signature.Recv() != nil && r.isHttpReqType(f.Signature.Recv().Type()) && f.Name() == ctxName {
// collect the return of r.Context
rets = append(rets, i.Value())
}
// check is r.Context
f := r.getFunction(instr)
if f == nil || f.Name() != ctxName {
break
}
if f.Signature.Recv() != nil {
// collect the return of r.Context
rets = append(rets, i.Value())
}
case *ssa.Store:
if !fromAddr {
checkRefs(i.Addr, true)
}
case *ssa.UnOp:
if r.isHttpReqType(i.Type()) {
checkRefs(i, false)
}
case *ssa.MakeClosure:
checkRefs(i, false)
case *ssa.Phi:
if r.isHttpReqType(i.Type()) {
checkRefs(i, false)
}
checkRefs(i, false)
case *ssa.MakeClosure:
case *ssa.Extract:
// http.Request can only be input
}
Expand All @@ -463,20 +424,15 @@ func (r *runner) getHttpReqCtx(f *ssa.Function) (rets []ssa.Value) {
for _, param := range f.Params {
if r.isHttpReqType(param.Type()) {
checkRefs(param, false)
}
}

for _, param := range f.FreeVars {
if r.isHttpReqType(param.Type()) {
checkRefs(param, false)
break
}
}

return
}

func (r *runner) checkFuncWithCtx(f *ssa.Function, tp int) {
isHttpHandler := tp&EntryWithHttpHandler != 0
func (r *runner) checkFuncWithCtx(f *ssa.Function, tp entryType) {
isHttpHandler := tp == EntryWithHttpHandler
refMap, ok := r.collectCtxRef(f, isHttpHandler)
if !ok {
return
Expand All @@ -496,15 +452,14 @@ func (r *runner) checkFuncWithCtx(f *ssa.Function, tp int) {

if tp&CtxIn != 0 {
if !refMap[instr] {
r.pass.Reportf(instr.Pos(), "Non-inherited new context, use function like `context.WithXXX` or `r.Context` instead")
if isHttpHandler {
r.pass.Reportf(instr.Pos(), "Non-inherited new context, use function like `context.WithXXX` or `r.Context` instead")
} else {
r.pass.Reportf(instr.Pos(), "Non-inherited new context, use function like `context.WithXXX` instead")
}
}
}

// only check if the ctx used in the current function is r.Context()
if isHttpHandler {
continue
}

ff := r.getFunction(instr)
if ff == nil {
continue
Expand Down Expand Up @@ -564,13 +519,13 @@ func (r *runner) checkFuncWithoutCtx(f *ssa.Function, checkingMap map[string]boo
continue
}

if entryType := r.checkIsEntry(ff); entryType&Entry == 0 {
if entryType := r.checkIsEntry(ff); entryType == EntryNone {
// cannot get info from fact, skip
if ff.Blocks == nil {
continue
}

// handler ring call
// handler cycle call
if checkingMap[key] {
continue
}
Expand Down Expand Up @@ -681,23 +636,21 @@ func (r *runner) isCtxType(tp types.Type) bool {
}

func (r *runner) isHttpResType(tp types.Type) bool {
var ok bool
for _, v := range r.httpResTyps {
if ok = types.Identical(v, v); ok {
break
if ok := types.Identical(v, v); ok {
return true
}
}
return ok
return false
}

func (r *runner) isHttpReqType(tp types.Type) bool {
var ok bool
for _, v := range r.httpReqTyps {
if ok = types.Identical(tp, v); ok {
break
if ok := types.Identical(tp, v); ok {
return true
}
}
return ok
return false
}

func (r *runner) getValue(key string, f *ssa.Function) (res resInfo, ok bool) {
Expand Down
21 changes: 15 additions & 6 deletions testdata/src/a/a.go
Expand Up @@ -48,7 +48,7 @@ func f1(ctx context.Context) {
f2(ctx)
}(ctx)

f2(context.Background()) // want "Non-inherited new context, use function like `context.WithXXX` or `r.Context` instead"
f2(context.Background()) // want "Non-inherited new context, use function like `context.WithXXX` instead"

thunk := MyInt.F
thunk(0)
Expand All @@ -66,7 +66,7 @@ func f3() {
func f4(ctx context.Context) {
f2(ctx)
ctx = context.Background()
f2(ctx) // want "Non-inherited new context, use function like `context.WithXXX` or `r.Context` instead"
f2(ctx) // want "Non-inherited new context, use function like `context.WithXXX` instead"
}

func f5(ctx context.Context) {
Expand Down Expand Up @@ -104,19 +104,28 @@ func f9(w http.ResponseWriter, r *http.Request) {
f8(context.Background(), w, r) // want "Non-inherited new context, use function like `context.WithXXX` or `r.Context` instead"
}

func f10() {
func f10(in bool, w http.ResponseWriter, r *http.Request) {
f8(r.Context(), w, r)
f8(context.Background(), w, r)
}

func f11() {
http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
f9(w, r)
f8(r.Context(), w, r)
f8(context.Background(), w, r) // want "Non-inherited new context, use function like `context.WithXXX` or `r.Context` instead"

f9(w, r)

// f10 should be like `func f10(ctx context.Context, in bool, w http.ResponseWriter, r *http.Request)`
f10(true, w, r) // want "Function `f10` should pass the context parameter"
})
}

/* ----------------- generics ----------------- */

type MySlice[T int | float32] []T

func (s MySlice[T]) f11(ctx context.Context) T {
func (s MySlice[T]) f12(ctx context.Context) T {
f3() // generics, Block is nil, wont report

var sum T
Expand All @@ -126,7 +135,7 @@ func (s MySlice[T]) f11(ctx context.Context) T {
return sum
}

func f12[T int | int8 | int16 | int32 | int64 | uint | uint8 | uint16 | uint32 | uint64 | float32 | float64](ctx context.Context, a, b T) T {
func f13[T int | int8 | int16 | int32 | int64 | uint | uint8 | uint16 | uint32 | uint64 | float32 | float64](ctx context.Context, a, b T) T {
f3() // generics, Block is nil, wont report

if a > b {
Expand Down

0 comments on commit ca63215

Please sign in to comment.