Skip to content

Commit

Permalink
bugfix: "select by field" don't work when permissions is not set (#50)
Browse files Browse the repository at this point in the history
* bugfix: "select by field" don't work since permissions are applied on it

Signed-off-by: julien brunet <julien@brunet.io>

* Support Returning function in query

* refactor

* cleaning, test for another request, pushed by mistake

* cleaning, some add was pushed by mistake

* Add support to RETURNING clause

* convert decimal or numeric postgres type to string

* adding ReturningByRequest mock

* adding returning clause on Delete
  • Loading branch information
xulien authored and felipeweb committed Aug 22, 2018
1 parent c656d56 commit 94d74c9
Show file tree
Hide file tree
Showing 4 changed files with 124 additions and 5 deletions.
1 change: 1 addition & 0 deletions adapter.go
Expand Up @@ -30,6 +30,7 @@ type Adapter interface {
ParseScript(scriptPath string, queryURL url.Values) (sqlQuery string, values []interface{}, err error)
Query(SQL string, params ...interface{}) (sc Scanner)
QueryCount(SQL string, params ...interface{}) (sc Scanner)
ReturningByRequest(r *http.Request) (returningSyntax string, err error)
SchemaClause(req *http.Request) (query string, hasCount bool)
SchemaOrderBy(order string, hasCount bool) (orderBy string)
SchemaTablesClause() (query string)
Expand Down
5 changes: 5 additions & 0 deletions mock/mock.go
Expand Up @@ -98,6 +98,11 @@ func (m *Mock) WhereByRequest(r *http.Request, initialPlaceholderID int) (whereS
return
}

// ReturningByRequest mock
func (m *Mock) ReturningByRequest(r *http.Request) (ReturningSyntax string, err error) {
return
}

// DatabaseClause mock
func (m *Mock) DatabaseClause(req *http.Request) (query string, hasCount bool) {
m.t.Helper()
Expand Down
92 changes: 87 additions & 5 deletions postgres/postgres.go
Expand Up @@ -249,6 +249,20 @@ func (adapter *Postgres) WhereByRequest(r *http.Request, initialPlaceholderID in
return
}

// ReturningByRequest create interface for queries + returning
func (adapter *Postgres) ReturningByRequest(r *http.Request) (returningSyntax string, err error) {
queries := r.URL.Query()["_returning"]
if len(queries) > 0 {
for i, q := range queries {
if i > 0 && i < len(queries) {
returningSyntax += ", "
}
returningSyntax += q
}
}
return
}

// SetByRequest create a set clause for SQL
func (adapter *Postgres) SetByRequest(r *http.Request, initialPlaceholderID int) (setSyntax string, values []interface{}, err error) {
body := make(map[string]interface{})
Expand Down Expand Up @@ -802,6 +816,38 @@ func (adapter *Postgres) Delete(SQL string, params ...interface{}) (sc adapters.
sc = &scanner.PrestScanner{Error: err}
return
}
if strings.Contains(SQL, "RETURNING") {
rows, _ := stmt.Query(params...)
cols, _ := rows.Columns()
var data []map[string]interface{}
for rows.Next() {
columns := make([]interface{}, len(cols))
columnPointers := make([]interface{}, len(cols))
for i := range columns {
columnPointers[i] = &columns[i]
}
if err := rows.Scan(columnPointers...); err != nil {
log.Fatal(err)
}
m := make(map[string]interface{})
for i, colName := range cols {
val := columnPointers[i].(*interface{})
switch (*val).(type) {
case []uint8:
m[colName] = string((*val).([]byte))
default:
m[colName] = *val
}
}
data = append(data, m)
}
jsonData, _ := json.Marshal(data)
sc = &scanner.PrestScanner{
Error: err,
Buff: bytes.NewBuffer(jsonData),
}
return
}
var result sql.Result
var rowsAffected int64
result, err = stmt.Exec(params...)
Expand Down Expand Up @@ -840,6 +886,38 @@ func (adapter *Postgres) Update(SQL string, params ...interface{}) (sc adapters.
return
}
log.Debugln("generated SQL:", SQL, " parameters: ", params)
if strings.Contains(SQL, "RETURNING") {
rows, _ := stmt.Query(params...)
cols, _ := rows.Columns()
var data []map[string]interface{}
for rows.Next() {
columns := make([]interface{}, len(cols))
columnPointers := make([]interface{}, len(cols))
for i := range columns {
columnPointers[i] = &columns[i]
}
if err := rows.Scan(columnPointers...); err != nil {
log.Fatal(err)
}
m := make(map[string]interface{})
for i, colName := range cols {
val := columnPointers[i].(*interface{})
switch (*val).(type) {
case []uint8:
m[colName] = string((*val).([]byte))
default:
m[colName] = *val
}
}
data = append(data, m)
}
jsonData, _ := json.Marshal(data)
sc = &scanner.PrestScanner{
Error: err,
Buff: bytes.NewBuffer(jsonData),
}
return
}
var result sql.Result
var rowsAffected int64
result, err = stmt.Exec(params...)
Expand Down Expand Up @@ -969,16 +1047,20 @@ func intersection(set, other []string) (intersection []string) {

// FieldsPermissions get fields permissions based in prest configuration
func (adapter *Postgres) FieldsPermissions(r *http.Request, table string, op string) (fields []string, err error) {
restrict := config.PrestConf.AccessConf.Restrict
if !restrict || op == "delete" {
fields = []string{"*"}
return
}
cols, err := columnsByRequest(r)
if err != nil {
err = fmt.Errorf("error on parse columns from request: %s", err)
return
}
restrict := config.PrestConf.AccessConf.Restrict
if !restrict || op == "delete" {
if len(cols) > 0 {
fields = cols
return
}
fields = []string{"*"}
return
}
allowedFields := fieldsByPermission(table, op)
if len(allowedFields) == 0 {
err = errors.New("there's no configured field for this table")
Expand Down
31 changes: 31 additions & 0 deletions postgres/postgres_test.go
Expand Up @@ -231,6 +231,37 @@ func TestInvalidWhereByRequest(t *testing.T) {
}
}

func TestReturningByRequest(t *testing.T) {
var testCases = []struct {
description string
url string
expectedSQL []string
err error
}{
{"Returning by request with nothing", "/prest/public/test_group_by_table", []string{""}, nil},
{"Returning by request with _returning=*", "/prest/public/test_group_by_table?_returning=*", []string{"RETURNING *"}, nil},
{"Returning by request with _returning=field", "/prest/public/test_group_by_table?_returning=age", []string{"RETURNING age"}, nil},
{"Returning by request with multiple _returning=field", "/prest/public/test_group_by_table?_returning=age&_returning=salary", []string{"RETURNING age,salary"}, nil},
}
for _, tc := range testCases {
t.Log(tc.description)
req, err := http.NewRequest("GET", tc.url, nil)
if err != nil {
t.Errorf("expected no errors in http request, got %v", err)
}
returning, err := config.PrestConf.Adapter.ReturningByRequest(req)
t.Log("returning:", returning)
if err != nil {
t.Errorf("expected no errors in returning by request, got %v", err)
}
for _, sql := range tc.expectedSQL {
if !strings.Contains(returning, sql) {
t.Errorf("expected %s in %s, but not was!", sql, returning)
}
}
}
}

func TestGroupByClause(t *testing.T) {
var testCases = []struct {
description string
Expand Down

0 comments on commit 94d74c9

Please sign in to comment.