diff --git a/postgres/postgres.go b/postgres/postgres.go index c3f22b6be..083fe713d 100644 --- a/postgres/postgres.go +++ b/postgres/postgres.go @@ -932,40 +932,68 @@ func (adapter *Postgres) TablePermissions(table string, op string) bool { return false } -// FieldsPermissions get fields permissions based in prest configuration -func (adapter *Postgres) FieldsPermissions(r *http.Request, table string, op string) (fields []string, err error) { - restrict := config.PrestConf.AccessConf.Restrict - cols := columnsByRequest(r) - queries := r.URL.Query() - if queries.Get("_groupby") != "" { - cols, err = normalizeAll(cols) - if err != nil { - return - } - } - if !restrict { - fields = cols - return - } - +func fieldsByPermission(table, op string) (fields []string) { tables := config.PrestConf.AccessConf.Tables for _, t := range tables { if t.Name == table { - for _, col := range cols { - // return all permitted fields if have "*" in SELECT - if op == "read" && col == "*" { + for _, perm := range t.Permissions { + if perm == op { fields = t.Fields - return - } - pField := checkField(col, t.Fields) - if pField != "" { - fields = append(fields, pField) } } - return } } - return nil, errors.New("0 tables configured") + return +} + +func containsAsterisk(arr []string) bool { + for _, e := range arr { + if e == "*" { + return true + } + } + return false +} + +func intersection(set, other []string) (intersection []string) { + for _, field := range set { + pField := checkField(field, other) + if pField != "" { + intersection = append(intersection, pField) + } + } + return +} + +// FieldsPermissions get fields permissions based in prest configuration +func (adapter *Postgres) FieldsPermissions(r *http.Request, table string, op string) (fields []string, err error) { + restrict := config.PrestConf.AccessConf.Restrict + if !restrict || op == "delete" { + fields = []string{"*"} + return + } + cols, err := columnsByRequest(r) + if err != nil { + err = fmt.Errorf("error on parse columns from request: %s", err) + return + } + allowedFields := fieldsByPermission(table, op) + if len(allowedFields) == 0 { + err = errors.New("there's no configured field for this table") + return + } + if containsAsterisk(allowedFields) { + fields = []string{"*"} + if len(cols) > 0 { + fields = cols + } + return + } + fields = intersection(cols, allowedFields) + if len(cols) == 0 { + fields = allowedFields + } + return } func checkField(col string, fields []string) (p string) { @@ -1006,23 +1034,22 @@ func normalizeColumn(col string) (gf string, err error) { } // columnsByRequest extract columns and return as array of strings -func columnsByRequest(r *http.Request) []string { - u, _ := r.URL.Parse(r.URL.String()) - columnsArr := u.Query()["_select"] - var columns []string - +func columnsByRequest(r *http.Request) (columns []string, err error) { + queries := r.URL.Query() + columnsArr := queries["_select"] for _, j := range columnsArr { cArgs := strings.Split(j, ",") for _, columnName := range cArgs { - if len(columnName) > 0 { - columns = append(columns, columnName) - } + columns = append(columns, columnName) } } - if len(columns) == 0 { - return []string{"*"} + if queries.Get("_groupby") != "" { + columns, err = normalizeAll(columns) + if err != nil { + return + } } - return columns + return } // DistinctClause get params in request to add distinct clause diff --git a/postgres/postgres_test.go b/postgres/postgres_test.go index 12a399e4a..f95b26900 100644 --- a/postgres/postgres_test.go +++ b/postgres/postgres_test.go @@ -852,39 +852,6 @@ func TestTablePermissions(t *testing.T) { } -func TestFieldsPermissions(t *testing.T) { - var testCases = []struct { - description string - url string - table string - permission string - resultLen int - }{ - {"Read valid field", "/prest/public/test_list_only_id?_select=id", "test_list_only_id", "read", 1}, - {"Read invalid field", "/prest/public/test_list_only_id?_select=name", "test_list_only_id", "read", 0}, - {"Read non existing field", "/prest/public/test_list_only_id?_select=non_existing_field", "test_list_only_id", "read", 0}, - {"Select with *", "/prest/public/test_list_only_id?_select=*", "test_list_only_id", "read", 1}, - {"Select with group function", "/prest/public/test_group_by_table?_select=age,sum:salary&_groupby=age", "test_group_by_table", "read", 2}, - } - - for _, tc := range testCases { - t.Log(tc.description) - - r, err := http.NewRequest("GET", tc.url, nil) - if err != nil { - t.Errorf("expected no errors on NewRequest, but got: %v", err) - } - - fields, err := config.PrestConf.Adapter.FieldsPermissions(r, tc.table, tc.permission) - if err != nil { - t.Errorf("expected no errors, but got %v", err) - } - if len(fields) != tc.resultLen { - t.Errorf("expected %d in table: %s, got: %d - %v", tc.resultLen, tc.table, len(fields), fields) - } - } -} - func TestRestrictFalse(t *testing.T) { config.PrestConf.AccessConf.Restrict = false @@ -963,17 +930,17 @@ func TestColumnsByRequest(t *testing.T) { {"Select array field from table", "/prest/public/testarray?_select=data", "data"}, {"Select fields from table", "/prest/public/test5?_select=celphone", "celphone"}, {"Select all from table", "/prest/public/test5?_select=*", "*"}, - {"Select with empty '_select' field", "/prest/public/test5?_select=", "*"}, + {"Select with empty '_select' field", "/prest/public/test5?_select=", ""}, {"Select with more columns", "/prest/public/test5?_select=celphone,battery", "celphone,battery"}, + {"Select with more columns", "/prest/public/test5?_select=age,sum:salary&_groupby=age", `age,SUM("salary")`}, } for _, tc := range testCases { - t.Log(tc.description) r, err := http.NewRequest("GET", tc.url, nil) if err != nil { t.Errorf("expected no errors on NewRequest, but got: %v", err) } - selectQuery := columnsByRequest(r) + selectQuery, _ := columnsByRequest(r) selectStr := strings.Join(selectQuery, ",") if selectStr != tc.expectedSQL { t.Errorf("expected %s, got: %s", tc.expectedSQL, selectStr) @@ -1270,3 +1237,226 @@ func TestPostgres_BatchInsertCopy(t *testing.T) { }) } } + +func TestPostgres_FieldsPermissions(t *testing.T) { + type args struct { + url string + table string + op string + fields []string + } + tests := []struct { + name string + args args + restrict bool + wantFields []string + wantErr bool + }{ + { + name: "delete operations always returns *", + args: args{ + op: "delete", + }, + wantFields: []string{"*"}, + }, + { + name: "if restrict is false returns *", + wantFields: []string{"*"}, + }, + { + name: "error on parse groupby request", + args: args{ + url: "/table_field_permission?_select=fail:fail&_groupby=fail", + }, + restrict: true, + wantErr: true, + }, + { + name: "error with no allowed fields", + args: args{ + url: "/table_field_permission", + }, + restrict: true, + wantErr: true, + }, + { + name: "allowed fields contains * and user don't pass select", + args: args{ + url: "/table_field_permission", + table: "test_field_permission", + op: "write", + fields: []string{"*"}, + }, + restrict: true, + wantErr: false, + wantFields: []string{"*"}, + }, + { + name: "allowed fields contains * and user ask for only only field", + args: args{ + url: "/table_field_permission?_select=name", + table: "test_field_permission", + op: "write", + fields: []string{"*"}, + }, + restrict: true, + wantErr: false, + wantFields: []string{"name"}, + }, + { + name: "allowed fields contains * and user ask for multiple fields", + args: args{ + url: "/table_field_permission?_select=name,age", + table: "test_field_permission", + op: "write", + fields: []string{"*"}, + }, + restrict: true, + wantErr: false, + wantFields: []string{"name", "age"}, + }, + { + name: "user ask for allowed field", + args: args{ + url: "/table_field_permission?_select=name", + table: "test_field_permission", + op: "write", + fields: []string{"name", "age"}, + }, + restrict: true, + wantErr: false, + wantFields: []string{"name"}, + }, + { + name: "user ask for not allowed field", + args: args{ + url: "/table_field_permission?_select=id", + table: "test_field_permission", + op: "write", + fields: []string{"name", "age"}, + }, + restrict: true, + wantErr: false, + }, + { + name: "allowed some fields but user ask for nothing", + args: args{ + url: "/table_field_permission", + table: "test_field_permission", + op: "write", + fields: []string{"name", "age"}, + }, + restrict: true, + wantErr: false, + wantFields: []string{"name", "age"}, + }, + { + name: "functions in select should respect table permissions", + args: args{ + url: "/table_field_permission?_groupby=number&_select=max:number", + table: "test_field_permission", + op: "write", + fields: []string{"name", "age"}, + }, + restrict: true, + wantErr: false, + }, + { + name: "select with function and allowed field returns field", + args: args{ + url: "/table_field_permission?_groupby=age&_select=max:age", + table: "test_field_permission", + op: "write", + fields: []string{"name", "age"}, + }, + restrict: true, + wantErr: false, + wantFields: []string{`MAX("age")`}, + }, + } + for _, tt := range tests { + config.PrestConf.AccessConf.Restrict = tt.restrict + config.PrestConf.AccessConf.Tables = []config.TablesConf{} + config.PrestConf.AccessConf.Tables = append(config.PrestConf.AccessConf.Tables, + config.TablesConf{ + Name: "test_field_permission", + Permissions: []string{"read", "write", "delete"}, + Fields: tt.args.fields, + }) + t.Run(tt.name, func(t *testing.T) { + adapter := &Postgres{} + r, err := http.NewRequest(http.MethodGet, tt.args.url, strings.NewReader("")) + if err != nil { + t.Fatal(err) + } + gotFields, err := adapter.FieldsPermissions(r, tt.args.table, tt.args.op) + if (err != nil) != tt.wantErr { + t.Errorf("Postgres.FieldsPermissions() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(gotFields, tt.wantFields) { + t.Errorf("Postgres.FieldsPermissions() = %v, want %v", gotFields, tt.wantFields) + } + }) + } +} + +func Test_intersection(t *testing.T) { + type args struct { + set []string + other []string + } + tests := []struct { + name string + args args + wantInter []string + }{ + {name: "two empty sets returns empty", wantInter: nil}, + {name: "intersection with empty set returns empty set", args: args{set: []string{"name"}, other: []string{}}, + wantInter: nil}, + {name: "intersection of empty set with other returns empty set", args: args{set: []string{}, other: []string{"name"}}, + wantInter: nil}, + {name: "intersection of two sets", args: args{set: []string{"name", "age"}, other: []string{"name"}}, + wantInter: []string{"name"}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if gotInter := intersection(tt.args.set, tt.args.other); !reflect.DeepEqual(gotInter, tt.wantInter) { + t.Errorf("intersection() = %v, want %v", gotInter, tt.wantInter) + } + }) + } +} + +func Test_containsAsterisk(t *testing.T) { + type args struct { + arr []string + } + tests := []struct { + name string + args args + want bool + }{ + { + name: "contains *", + args: args{ + arr: []string{"*"}, + }, + want: true, + }, + { + name: "dont contains *", + args: args{ + arr: []string{}, + }, + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := containsAsterisk(tt.args.arr); got != tt.want { + t.Errorf("containsAsterisk() = %v, want %v", got, tt.want) + } + }) + } +}