Skip to content
This repository has been archived by the owner on Sep 21, 2021. It is now read-only.

Add gorm and sqlx support, make easy add new other ORM #10

Merged
merged 4 commits into from
Dec 21, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ How does it work?
-----------------

SafeSQL uses the static analysis utilities in [go/tools][tools] to search for
all call sites of each of the `query` functions in package [database/sql][sql]
(i.e., functions which accept a `string` parameter named `query`). It then makes
all call sites of each of the `query` functions in packages ([database/sql][sql],[github.com/jinzhu/gorm][gorm],[github.com/jmoiron/sqlx][sqlx])
(i.e., functions which accept a parameter named `query`,`sql`). It then makes
sure that every such call site uses a query that is a compile-time constant.

The principle behind SafeSQL's safety guarantees is that queries that are
Expand All @@ -44,6 +44,8 @@ will not be allowed.

[tools]: https://godoc.org/golang.org/x/tools/go
[sql]: http://golang.org/pkg/database/sql/
[sqlx]: https://github.com/jmoiron/sqlx
[gorm]: https://github.com/jinzhu/gorm

False positives
---------------
Expand All @@ -66,8 +68,6 @@ a fundamental limitation: SafeSQL could recursively trace the `query` argument
through every intervening helper function to ensure that its argument is always
constant, but this code has yet to be written.

If you use a wrapper for `database/sql` (e.g., [`sqlx`][sqlx]), it's likely
SafeSQL will not work for you because of this.

The second sort of false positive is based on a limitation in the sort of
analysis SafeSQL performs: there are many safe SQL statements which are not
Expand All @@ -76,4 +76,3 @@ static analysis techniques (such as taint analysis) or user-provided safety
annotations would be able to reduce the number of false positives, but this is
expected to be a significant undertaking.

[sqlx]: https://github.com/jmoiron/sqlx
109 changes: 95 additions & 14 deletions safesql.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"go/build"
"go/types"
"os"

"path/filepath"
"strings"

Expand All @@ -19,6 +20,27 @@ import (
"golang.org/x/tools/go/ssa/ssautil"
)

type sqlPackage struct {
packageName string
paramNames []string
enable bool
}

var sqlPackages = []sqlPackage{
{
packageName: "database/sql",
paramNames: []string{"query"},
},
{
packageName: "github.com/jinzhu/gorm",
paramNames: []string{"sql", "query"},
},
{
packageName: "github.com/jmoiron/sqlx",
paramNames: []string{"query"},
},
}

func main() {
var verbose, quiet bool
flag.BoolVar(&verbose, "v", false, "Verbose mode")
Expand All @@ -38,21 +60,45 @@ func main() {
c := loader.Config{
FindPackage: FindPackage,
}
c.Import("database/sql")
for _, pkg := range pkgs {
c.Import(pkg)
}
p, err := c.Load()

if err != nil {
fmt.Printf("error loading packages %v: %v\n", pkgs, err)
os.Exit(2)
}

imports := getImports(p)
existOne := false
for i := range sqlPackages {
if _, exist := imports[sqlPackages[i].packageName]; exist {
if verbose {
fmt.Printf("Enabling support for %s\n", sqlPackages[i].packageName)
}
sqlPackages[i].enable = true
existOne = true
}
}
if !existOne {
fmt.Printf("No packages in %v include a supported database driver", pkgs)
os.Exit(2)
}

s := ssautil.CreateProgram(p, 0)
s.Build()

qms := FindQueryMethods(p.Package("database/sql").Pkg, s)
qms := make([]*QueryMethod, 0)

for i := range sqlPackages {
if sqlPackages[i].enable {
qms = append(qms, FindQueryMethods(sqlPackages[i], p.Package(sqlPackages[i].packageName).Pkg, s)...)
}
}

if verbose {
fmt.Println("database/sql functions that accept queries:")
fmt.Println("database driver functions that accept queries:")
for _, m := range qms {
fmt.Printf("- %s (param %d)\n", m.Func, m.Param)
}
Expand All @@ -75,21 +121,27 @@ func main() {
}

bad := FindNonConstCalls(res.CallGraph, qms)

if len(bad) == 0 {
if !quiet {
fmt.Println(`You're safe from SQL injection! Yay \o/`)
}
return
}

fmt.Printf("Found %d potentially unsafe SQL statements:\n", len(bad))
if verbose {
fmt.Printf("Found %d potentially unsafe SQL statements:\n", len(bad))
}

for _, ci := range bad {
pos := p.Fset.Position(ci.Pos())
fmt.Printf("- %s\n", pos)
}
fmt.Println("Please ensure that all SQL queries you use are compile-time constants.")
fmt.Println("You should always use parameterized queries or prepared statements")
fmt.Println("instead of building queries from strings.")
if verbose {
fmt.Println("Please ensure that all SQL queries you use are compile-time constants.")
fmt.Println("You should always use parameterized queries or prepared statements")
fmt.Println("instead of building queries from strings.")
}
os.Exit(1)
}

Expand All @@ -104,7 +156,7 @@ type QueryMethod struct {

// FindQueryMethods locates all methods in the given package (assumed to be
// package database/sql) with a string parameter named "query".
func FindQueryMethods(sql *types.Package, ssa *ssa.Program) []*QueryMethod {
func FindQueryMethods(sqlPackages sqlPackage, sql *types.Package, ssa *ssa.Program) []*QueryMethod {
methods := make([]*QueryMethod, 0)
scope := sql.Scope()
for _, name := range scope.Names() {
Expand All @@ -122,7 +174,7 @@ func FindQueryMethods(sql *types.Package, ssa *ssa.Program) []*QueryMethod {
continue
}
s := m.Type().(*types.Signature)
if num, ok := FuncHasQuery(s); ok {
if num, ok := FuncHasQuery(sqlPackages, s); ok {
methods = append(methods, &QueryMethod{
Func: m,
SSA: ssa.FuncValue(m),
Expand All @@ -135,16 +187,16 @@ func FindQueryMethods(sql *types.Package, ssa *ssa.Program) []*QueryMethod {
return methods
}

var stringType types.Type = types.Typ[types.String]

// FuncHasQuery returns the offset of the string parameter named "query", or
// none if no such parameter exists.
func FuncHasQuery(s *types.Signature) (offset int, ok bool) {
func FuncHasQuery(sqlPackages sqlPackage, s *types.Signature) (offset int, ok bool) {
params := s.Params()
for i := 0; i < params.Len(); i++ {
v := params.At(i)
if v.Name() == "query" && v.Type() == stringType {
return i, true
for _, paramName := range sqlPackages.paramNames {
if v.Name() == paramName {
return i, true
}
}
}
return 0, false
Expand All @@ -164,6 +216,16 @@ func FindMains(p *loader.Program, s *ssa.Program) []*ssa.Package {
return mains
}

func getImports(p *loader.Program) map[string]interface{} {
pkgs := make(map[string]interface{})
for _, pkg := range p.AllPackages {
if pkg.Importable {
pkgs[pkg.Pkg.Path()] = nil
}
}
return pkgs
}

// FindNonConstCalls returns the set of callsites of the given set of methods
// for which the "query" parameter is not a compile-time constant.
func FindNonConstCalls(cg *callgraph.Graph, qms []*QueryMethod) []ssa.CallInstruction {
Expand All @@ -186,6 +248,18 @@ func FindNonConstCalls(cg *callgraph.Graph, qms []*QueryMethod) []ssa.CallInstru
if _, ok := okFuncs[edge.Site.Parent()]; ok {
continue
}

isInternalSQLPkg := false
for _, pkg := range sqlPackages {
if pkg.packageName == edge.Caller.Func.Pkg.Pkg.Path() {
isInternalSQLPkg = true
break
}
}
if isInternalSQLPkg {
continue
}

cc := edge.Site.Common()
args := cc.Args
// The first parameter is occasionally the receiver.
Expand All @@ -195,7 +269,14 @@ func FindNonConstCalls(cg *callgraph.Graph, qms []*QueryMethod) []ssa.CallInstru
panic("arg count mismatch")
}
v := args[m.Param]

if _, ok := v.(*ssa.Const); !ok {
if inter, ok := v.(*ssa.MakeInterface); ok && types.IsInterface(v.(*ssa.MakeInterface).Type()) {
if inter.X.Referrers() == nil || inter.X.Type() != types.Typ[types.String] {
continue
}
}

bad = append(bad, edge.Site)
}
}
Expand Down