Skip to content

Commit

Permalink
Refactor SQL rules for better extensibility
Browse files Browse the repository at this point in the history
Remove hardwired assumption and heuristics on index of arg taking a SQL
string, be explicit about it instead.
  • Loading branch information
scop committed Jul 22, 2022
1 parent 0212c83 commit a18c3a0
Showing 1 changed file with 58 additions and 19 deletions.
77 changes: 58 additions & 19 deletions rules/sql.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@
package rules

import (
"fmt"
"go/ast"
"regexp"
"strings"

"github.com/securego/gosec/v2"
)
Expand All @@ -30,6 +30,51 @@ type sqlStatement struct {
patterns []*regexp.Regexp
}

var sqlCallIdents = map[string]map[string]int{

This comment has been minimized.

Copy link
@ccojocar

ccojocar Jul 29, 2022

Member

Can you explain what's the reason behind using this 0/1 in a map?

This comment has been minimized.

Copy link
@scop

scop Jul 29, 2022

Author Contributor

0 and 1 are the indexes of the argument we are interested in. So for Exec, we're interested in the first (index 0) argument, for ExecContext we're interested in the second (index 1), etc.

"*database/sql.DB": {
"Exec": 0,
"ExecContext": 1,
"Query": 0,
"QueryContext": 1,
"QueryRow": 0,
"QueryRowContext": 1,
"Prepare": 0,
"PrepareContext": 1,
},
"*database/sql.Tx": {
"Exec": 0,
"ExecContext": 1,
"Query": 0,
"QueryContext": 1,
"QueryRow": 0,
"QueryRowContext": 1,
"Prepare": 0,
"PrepareContext": 1,
},
}

// findQueryArg locates the argument taking raw SQL
func findQueryArg(call *ast.CallExpr, ctx *gosec.Context) (ast.Expr, error) {
typeName, fnName, err := gosec.GetCallInfo(call, ctx)
if err != nil {
return nil, err
}
i := -1
if ni, ok := sqlCallIdents[typeName]; ok {
if i, ok = ni[fnName]; !ok {
i = -1

This comment has been minimized.

Copy link
@ccojocar

ccojocar Jul 29, 2022

Member

Are you planning to increment i?

This comment has been minimized.

Copy link
@scop

scop Jul 29, 2022

Author Contributor

No, we just reset it to -1 for unknown functions, and handle as the error case in the if i == -1 below. Granted, this should never happen, as we only ever rule.Add the same ones.

}
}
if i == -1 {
return nil, fmt.Errorf("SQL argument index not found for %s.%s", typeName, fnName)
}
if i >= len(call.Args) {

This comment has been minimized.

Copy link
@ccojocar

ccojocar Jul 29, 2022

Member

This will never happen, you assign always -1 to variable i.

This comment has been minimized.

Copy link
@scop

scop Jul 29, 2022

Author Contributor

It happens if we have an argument index in sqlCallIdents that is larger than or equal to the number of arguments in code we see for that function. For example, if we had 100 for Exec there. This is quite theoretical given stdlib stability, and such code likely wouldn't compile and thus we might never hit this in real life, but it's good to handle it for completeness.

return nil, nil
}
query := call.Args[i]
return query, nil
}

func (s *sqlStatement) ID() string {
return s.MetaData.ID
}
Expand Down Expand Up @@ -69,16 +114,10 @@ func (s *sqlStrConcat) checkObject(n *ast.Ident, c *gosec.Context) bool {

// checkQuery verifies if the query parameters is a string concatenation
func (s *sqlStrConcat) checkQuery(call *ast.CallExpr, ctx *gosec.Context) (*gosec.Issue, error) {
_, fnName, err := gosec.GetCallInfo(call, ctx)
query, err := findQueryArg(call, ctx)
if err != nil {
return nil, err
}
var query ast.Node
if strings.HasSuffix(fnName, "Context") {
query = call.Args[1]
} else {
query = call.Args[0]
}

if be, ok := query.(*ast.BinaryExpr); ok {
operands := gosec.GetBinaryExprOperands(be)
Expand Down Expand Up @@ -137,8 +176,11 @@ func NewSQLStrConcat(id string, conf gosec.Config) (gosec.Rule, []ast.Node) {
},
}

rule.AddAll("*database/sql.DB", "Query", "QueryContext", "QueryRow", "QueryRowContext", "Exec", "ExecContext", "Prepare", "PrepareContext")
rule.AddAll("*database/sql.Tx", "Query", "QueryContext", "QueryRow", "QueryRowContext", "Exec", "ExecContext", "Prepare", "PrepareContext")
for s, si := range sqlCallIdents {
for i := range si {
rule.Add(s, i)
}
}
return rule, []ast.Node{(*ast.AssignStmt)(nil), (*ast.ExprStmt)(nil)}
}

Expand Down Expand Up @@ -171,16 +213,10 @@ func (s *sqlStrFormat) constObject(e ast.Expr, c *gosec.Context) bool {
}

func (s *sqlStrFormat) checkQuery(call *ast.CallExpr, ctx *gosec.Context) (*gosec.Issue, error) {
_, fnName, err := gosec.GetCallInfo(call, ctx)
query, err := findQueryArg(call, ctx)
if err != nil {
return nil, err
}
var query ast.Node
if strings.HasSuffix(fnName, "Context") {
query = call.Args[1]
} else {
query = call.Args[0]
}

if ident, ok := query.(*ast.Ident); ok && ident.Obj != nil {
decl := ident.Obj.Decl
Expand Down Expand Up @@ -306,8 +342,11 @@ func NewSQLStrFormat(id string, conf gosec.Config) (gosec.Rule, []ast.Node) {
},
},
}
rule.AddAll("*database/sql.DB", "Query", "QueryContext", "QueryRow", "QueryRowContext", "Exec", "ExecContext", "Prepare", "PrepareContext")
rule.AddAll("*database/sql.Tx", "Query", "QueryContext", "QueryRow", "QueryRowContext", "Exec", "ExecContext", "Prepare", "PrepareContext")
for s, si := range sqlCallIdents {
for i := range si {
rule.Add(s, i)
}
}
rule.fmtCalls.AddAll("fmt", "Sprint", "Sprintf", "Sprintln", "Fprintf")
rule.noIssue.AddAll("os", "Stdout", "Stderr")
rule.noIssueQuoted.Add("github.com/lib/pq", "QuoteIdentifier")
Expand Down

0 comments on commit a18c3a0

Please sign in to comment.