Skip to content

Commit

Permalink
Merge pull request #40 from nuveo/crg_prevent_sql_injection
Browse files Browse the repository at this point in the history
Prevent SQL injection closes #36
  • Loading branch information
felipeweb committed Dec 14, 2016
2 parents b4826b9 + 0fbb68c commit ae3db8f
Show file tree
Hide file tree
Showing 9 changed files with 250 additions and 77 deletions.
123 changes: 103 additions & 20 deletions adapters/postgres/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,13 @@ package postgres

import (
"encoding/json"
"errors"
"fmt"
"net/http"
"net/url"
"strconv"
"strings"
"unicode"

"github.com/caarlos0/env"
"github.com/jmoiron/sqlx"
Expand All @@ -21,7 +24,7 @@ import (
const (
pageNumberKey = "_page"
pageSizeKey = "_page_size"
defaultPageSize = "10"
defaultPageSize = 10
)

// Conn connect on PostgreSQL
Expand All @@ -37,44 +40,100 @@ func Conn() (db *sqlx.DB) {
return
}

// chkInvaidIdentifier return true if identifier is invalid
func chkInvaidIdentifier(identifer string) bool {
if len(identifer) > 63 ||
unicode.IsDigit([]rune(identifer)[0]) {
return true
}

for _, v := range identifer {
if !unicode.IsLetter(v) &&
!unicode.IsDigit(v) &&
v != '_' &&
v != '.' {
return true
}
}
return false
}

// WhereByRequest create interface for queries + where
func WhereByRequest(r *http.Request) (whereSyntax string) {
func WhereByRequest(r *http.Request, initialPlaceholderID int) (whereSyntax string, values []interface{}, err error) {
//whereMap := make(map[string]string)
whereKey := []string{}
whereValues := []string{}

u, _ := url.Parse(r.URL.String())
where := []string{}
pid := initialPlaceholderID
for key, val := range u.Query() {
if !strings.HasPrefix(key, "_") {
keyInfo := strings.Split(key, ":")
if len(keyInfo) > 1 {
switch keyInfo[1] {
case "jsonb":
jsonField := strings.Split(keyInfo[0], "->>")
where = append(where, fmt.Sprintf("%s->>'%s'='%s'", jsonField[0], jsonField[1], val[0]))
if chkInvaidIdentifier(jsonField[0]) ||
chkInvaidIdentifier(jsonField[1]) {
err = errors.New("Invalid identifier")
return
}
whereKey = append(whereKey, fmt.Sprintf("%s->>'%s'=$%d", jsonField[0], jsonField[1], pid))
whereValues = append(whereValues, val[0])
default:
where = append(where, fmt.Sprintf("%s='%s'", keyInfo[0], val[0]))

if chkInvaidIdentifier(keyInfo[0]) {
err = errors.New("Invalid identifier")
return
}
whereKey = append(whereKey, fmt.Sprintf("%s=$%d", keyInfo[0], pid))
whereValues = append(whereValues, val[0])
}
continue
}
where = append(where, fmt.Sprintf("%s='%s'", key, val[0]))
if chkInvaidIdentifier(key) {
err = errors.New("Invalid identifier")
return
}

whereKey = append(whereKey, fmt.Sprintf("%s=$%d", key, pid))
whereValues = append(whereValues, val[0])

pid++
}
}

whereSyntax = strings.Join(where, " and ")
for i := 0; i < len(whereKey); i++ {
if whereSyntax == "" {
whereSyntax += whereKey[i]
} else {
whereSyntax += " AND " + whereKey[i]
}

values = append(values, whereValues[i])
}

return
}

// Query process queries
func Query(SQL string, params ...interface{}) (jsonData []byte, err error) {
db := Conn()
rows, err := db.Queryx(SQL, params...)

prepare, err := db.Prepare(SQL)

if err != nil {
return
}

rows, err := prepare.Query(params...)
if err != nil {
return nil, err
return
}
defer rows.Close()

columns, err := rows.Columns()
if err != nil {
return nil, err
return
}

count := len(columns)
Expand Down Expand Up @@ -106,19 +165,25 @@ func Query(SQL string, params ...interface{}) (jsonData []byte, err error) {
}

// PaginateIfPossible func
func PaginateIfPossible(r *http.Request) (paginatedQuery string) {
func PaginateIfPossible(r *http.Request) (paginatedQuery string, err error) {
u, _ := url.Parse(r.URL.String())
values := u.Query()
if _, ok := values[pageNumberKey]; !ok {
paginatedQuery = ""
return
}
pageNumber := values[pageNumberKey][0]
pageNumber, err := strconv.Atoi(values[pageNumberKey][0])
if err != nil {
return
}
pageSize := defaultPageSize
if size, ok := values[pageSizeKey]; ok {
pageSize = size[0]
pageSize, err = strconv.Atoi(size[0])
if err != nil {
return
}
}
paginatedQuery = fmt.Sprintf("LIMIT %s OFFSET(%s - 1) * %s", pageSize, pageNumber, pageSize)
paginatedQuery = fmt.Sprintf("LIMIT %d OFFSET(%d - 1) * %d", pageSize, pageNumber, pageSize)
return
}

Expand Down Expand Up @@ -154,7 +219,7 @@ func Insert(database, schema, table string, body api.Request) (jsonData []byte,
}

// Delete execute delete sql into a table
func Delete(database, schema, table, where string) (jsonData []byte, err error) {
func Delete(database, schema, table, where string, whereValues []interface{}) (jsonData []byte, err error) {
var result sql.Result
var rowsAffected int64

Expand All @@ -167,7 +232,7 @@ func Delete(database, schema, table, where string) (jsonData []byte, err error)
}

db := Conn()
result, err = db.Exec(sql)
result, err = db.Exec(sql, whereValues...)
if err != nil {
return
}
Expand All @@ -183,13 +248,17 @@ func Delete(database, schema, table, where string) (jsonData []byte, err error)
}

// Update execute update sql into a table
func Update(database, schema, table, where string, body api.Request) (jsonData []byte, err error) {
func Update(database, schema, table, where string, whereValues []interface{}, body api.Request) (jsonData []byte, err error) {
var result sql.Result
var rowsAffected int64

fields := []string{}
values := make([]interface{}, 0)
pid := len(whereValues) + 1 // placeholder id
for key, value := range body.Data {
fields = append(fields, fmt.Sprintf("%s='%s'", key, value))
fields = append(fields, fmt.Sprintf("%s=$%d", key, pid))
values = append(values, value)
pid++
}
setSyntax := strings.Join(fields, ", ")

Expand All @@ -200,13 +269,27 @@ func Update(database, schema, table, where string, body api.Request) (jsonData [
sql,
" WHERE ",
where)
values = append(values, whereValues...)
}

db := Conn()
result, err = db.Exec(sql)
//result, err = db.Exec(sql, values)
stmt, err := db.Prepare(sql)
if err != nil {
return
}

valuesAux := make([]interface{}, 0, len(values))

for i := 0; i < len(values); i++ {
valuesAux = append(valuesAux, values[i])
}

result, err = stmt.Exec(valuesAux...)
if err != nil {
return
}

rowsAffected, err = result.RowsAffected()
if err != nil {
return
Expand Down
31 changes: 20 additions & 11 deletions adapters/postgres/postgres_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,27 @@ func TestWhereByRequest(t *testing.T) {
Convey("Where by request without paginate", t, func() {
r, err := http.NewRequest("GET", "/databases?dbname=prest&test=cool", nil)
So(err, ShouldBeNil)
where := WhereByRequest(r)
So(where, ShouldContainSubstring, "dbname='prest'")
So(where, ShouldContainSubstring, "test='cool'")
So(where, ShouldContainSubstring, "and")

where, values, err := WhereByRequest(r, 1)
So(err, ShouldBeNil)
So(where, ShouldContainSubstring, "dbname=$")
So(where, ShouldContainSubstring, "test=$")
So(where, ShouldContainSubstring, " AND ")
So(values, ShouldContain, "prest")
So(values, ShouldContain, "cool")
})

Convey("Where by request with jsonb field", t, func() {
r, err := http.NewRequest("GET", "/prest/public/test?name=nuveo&data->>description:jsonb=bla", nil)
So(err, ShouldBeNil)
where := WhereByRequest(r)
So(where, ShouldContainSubstring, "name='nuveo'")
So(where, ShouldContainSubstring, "data->>'description'='bla'")
So(where, ShouldContainSubstring, "and")

where, values, err := WhereByRequest(r, 1)
So(err, ShouldBeNil)
So(where, ShouldContainSubstring, "name=$")
So(where, ShouldContainSubstring, "data->>'description'=$")
So(where, ShouldContainSubstring, " AND ")
So(values, ShouldContain, "nuveo")
So(values, ShouldContain, "bla")
})
}

Expand Down Expand Up @@ -57,7 +65,8 @@ func TestPaginateIfPossible(t *testing.T) {
Convey("Paginate if possible", t, func() {
r, err := http.NewRequest("GET", "/databases?dbname=prest&test=cool&_page=1&_page_size=20", nil)
So(err, ShouldBeNil)
where := PaginateIfPossible(r)
where, err := PaginateIfPossible(r)
So(err, ShouldBeNil)
So(where, ShouldContainSubstring, "LIMIT 20 OFFSET(1 - 1) * 20")
})
}
Expand All @@ -77,7 +86,7 @@ func TestInsert(t *testing.T) {

func TestDelete(t *testing.T) {
Convey("Delete data from table", t, func() {
json, err := Delete("prest", "public", "test", "name='nuveo'")
json, err := Delete("prest", "public", "test", "name=$1", []interface{}{"nuveo"})
So(err, ShouldBeNil)
So(len(json), ShouldBeGreaterThan, 0)
})
Expand All @@ -90,7 +99,7 @@ func TestUpdate(t *testing.T) {
"name": "prest",
},
}
json, err := Update("prest", "public", "test", "name='prest'", r)
json, err := Update("prest", "public", "test", "name=$1", []interface{}{"prest"}, r)
So(err, ShouldBeNil)
So(len(json), ShouldBeGreaterThan, 0)
})
Expand Down
20 changes: 17 additions & 3 deletions controllers/databases.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,13 @@ import (

// GetDatabases list all (or filter) databases
func GetDatabases(w http.ResponseWriter, r *http.Request) {
requestWhere := postgres.WhereByRequest(r)
requestWhere, values, err := postgres.WhereByRequest(r, 1)
if err != nil {
log.Println(err)
http.Error(w, err.Error(), http.StatusBadRequest)
return
}

sqlDatabases := statements.Databases
if requestWhere != "" {
sqlDatabases = fmt.Sprint(
Expand All @@ -21,8 +27,16 @@ func GetDatabases(w http.ResponseWriter, r *http.Request) {
requestWhere,
statements.DatabasesOrderBy)
}
sqlDatabases = fmt.Sprint(sqlDatabases, " ", postgres.PaginateIfPossible(r))
object, err := postgres.Query(sqlDatabases)

page, err := postgres.PaginateIfPossible(r)
if err != nil {
http.Error(w, "Paging error", http.StatusBadRequest)
return
}

sqlDatabases = fmt.Sprint(sqlDatabases, " ", page)

object, err := postgres.Query(sqlDatabases, values...)
if err != nil {
log.Println(err)
http.Error(w, err.Error(), http.StatusInternalServerError)
Expand Down
6 changes: 3 additions & 3 deletions controllers/databases_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,20 @@ func TestGetDatabases(t *testing.T) {
r, err := http.NewRequest("GET", "/databases", nil)
w := httptest.NewRecorder()
So(err, ShouldBeNil)
validate(w, r, GetDatabases)
validate(w, r, GetDatabases, "TestGetDatabases")
})

Convey("Get databases with custom where clause", t, func() {
r, err := http.NewRequest("GET", "/databases?datname=prest", nil)
w := httptest.NewRecorder()
So(err, ShouldBeNil)
validate(w, r, GetDatabases)
validate(w, r, GetDatabases, "TestGetDatabases")
})

Convey("Get databases with custom where clause and pagination", t, func() {
r, err := http.NewRequest("GET", "/databases?datname=prest&_page=1&_page_size=20", nil)
w := httptest.NewRecorder()
So(err, ShouldBeNil)
validate(w, r, GetDatabases)
validate(w, r, GetDatabases, "TestGetDatabases")
})
}
19 changes: 16 additions & 3 deletions controllers/schemas.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,13 @@ import (

// GetSchemas list all (or filter) schemas
func GetSchemas(w http.ResponseWriter, r *http.Request) {
requestWhere := postgres.WhereByRequest(r)
requestWhere, values, err := postgres.WhereByRequest(r, 1)
if err != nil {
log.Println(err)
http.Error(w, err.Error(), http.StatusBadRequest)
return
}

sqlSchemas := statements.Schemas
if requestWhere != "" {
sqlSchemas = fmt.Sprint(
Expand All @@ -20,8 +26,15 @@ func GetSchemas(w http.ResponseWriter, r *http.Request) {
requestWhere,
statements.SchemasOrderBy)
}
sqlSchemas = fmt.Sprint(sqlSchemas, " ", postgres.PaginateIfPossible(r))
object, err := postgres.Query(sqlSchemas)

page, err := postgres.PaginateIfPossible(r)
if err != nil {
http.Error(w, "Paging error", http.StatusBadRequest)
return
}

sqlSchemas = fmt.Sprint(sqlSchemas, " ", page)
object, err := postgres.Query(sqlSchemas, values...)
if err != nil {
log.Println(err)
http.Error(w, err.Error(), http.StatusInternalServerError)
Expand Down

0 comments on commit ae3db8f

Please sign in to comment.