Skip to content

Commit

Permalink
Merge 7cdf6fb into c58e30b
Browse files Browse the repository at this point in the history
  • Loading branch information
vpsx committed Sep 20, 2019
2 parents c58e30b + 7cdf6fb commit db16b60
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 40 deletions.
113 changes: 79 additions & 34 deletions arborist/policy.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package arborist

import (
"database/sql"
"encoding/json"
"fmt"
"strings"
Expand Down Expand Up @@ -34,13 +35,11 @@ func (policy *Policy) UnmarshalJSON(data []byte) error {
return err
}

// delete fields which should be ignored in user input
delete(fields, "ID")
// uncomment this after 3.0.0 and lowercase the ID field in optionalFields too
// delete(fields, "id")

// id is optional here because PUT doesn't require it to be in the json;
// handlePolicyOverwrite will populate id later, from the URL.
// id is still validated later, in policy `validate` function.
optionalFields := map[string]struct{}{
"ID": struct{}{},
"id": struct{}{},
"description": struct{}{},
}
err = validateJSON("policy", policy, fields, optionalFields)
Expand Down Expand Up @@ -169,6 +168,9 @@ func (policy *Policy) roles(tx *sqlx.Tx) ([]RoleFromQuery, error) {
// looking at the database. This includes that the policy must contain at least
// one resource and at least one role.
func (policy *Policy) validate() *ErrorResponse {
if len(policy.Name) == 0 {
return newErrorResponse("policy ID cannot be absent or empty", 400, nil)
}
// Resources and roles must be non-empty
if len(policy.ResourcePaths) == 0 {
return newErrorResponse("no resource paths specified", 400, nil)
Expand All @@ -179,26 +181,9 @@ func (policy *Policy) validate() *ErrorResponse {
return nil
}

// createInDb writes out the policy to the database.
func (policy *Policy) createInDb(tx *sqlx.Tx) *ErrorResponse {
errResponse := policy.validate()
if errResponse != nil {
return errResponse
}

var policyID int
// TODO: make sure description works as expected
stmt := "INSERT INTO policy(name, description) VALUES ($1, $2) RETURNING id"
row := tx.QueryRowx(stmt, policy.Name, policy.Description)
err := row.Scan(&policyID)
if err != nil {
// should add more checking here to guarantee the correct error
_ = tx.Rollback()
// this should only fail because the policy was not unique. return error
// accordingly
msg := fmt.Sprintf("failed to insert policy: policy with this ID already exists: %s", policy.Name)
return newErrorResponse(msg, 409, &err)
}
// addResourcesAndRoles takes a policy and links it in the database
// to each of its resources and roles.
func (policy *Policy) addResourcesAndRoles(tx *sqlx.Tx, policyID int) *ErrorResponse {

// `resources` is a list of looked-up resources which appear in the input policy
resources, err := policy.resources(tx)
Expand All @@ -219,21 +204,19 @@ func (policy *Policy) createInDb(tx *sqlx.Tx) *ErrorResponse {
}
}
if len(missingResources) > 0 {
_ = tx.Rollback()
missingString := strings.Join(missingResources, ", ")
msg := fmt.Sprintf("failed to create policy: resources do not exist: %s", missingString)
return newErrorResponse(msg, 400, nil)
}
// try to insert relationships from this policy to all resources
stmt = multiInsertStmt("policy_resource(policy_id, resource_id)", len(resources))
stmt := multiInsertStmt("policy_resource(policy_id, resource_id)", len(resources))
policyResourceRows := []interface{}{}
for _, resource := range resources {
policyResourceRows = append(policyResourceRows, policyID)
policyResourceRows = append(policyResourceRows, resource.ID)
}
_, err = tx.Exec(stmt, policyResourceRows...)
if err != nil {
_ = tx.Rollback()
msg := fmt.Sprintf("failed to insert policy while linking resources: %s", err.Error())
return newErrorResponse(msg, 500, &err)
}
Expand All @@ -255,7 +238,6 @@ func (policy *Policy) createInDb(tx *sqlx.Tx) *ErrorResponse {
}
}
if len(missingRoles) > 0 {
_ = tx.Rollback()
missingString := strings.Join(missingRoles, ", ")
msg := fmt.Sprintf("failed to create policy: roles do not exist: %s", missingString)
return newErrorResponse(msg, 400, nil)
Expand All @@ -269,14 +251,41 @@ func (policy *Policy) createInDb(tx *sqlx.Tx) *ErrorResponse {
}
_, err = tx.Exec(stmt, policyRoleRows...)
if err != nil {
_ = tx.Rollback()
msg := fmt.Sprintf("failed to insert policy while linking roles: %s", err.Error())
return newErrorResponse(msg, 500, &err)
}

return nil
}

// createInDb writes out the policy to the database.
func (policy *Policy) createInDb(tx *sqlx.Tx) *ErrorResponse {
errResponse := policy.validate()
if errResponse != nil {
return errResponse
}

var policyID int
// TODO: make sure description works as expected
stmt := "INSERT INTO policy(name, description) VALUES ($1, $2) RETURNING id"
row := tx.QueryRowx(stmt, policy.Name, policy.Description)
err := row.Scan(&policyID)
if err != nil {
// should add more checking here to guarantee the correct error
// this should only fail because the policy was not unique. return error
// accordingly
msg := fmt.Sprintf("failed to insert policy: policy with this ID already exists: %s", policy.Name)
return newErrorResponse(msg, 409, &err)
}

errResponse = policy.addResourcesAndRoles(tx, policyID)
if errResponse != nil {
return errResponse
}

return nil
}

func (policy *Policy) deleteInDb(tx *sqlx.Tx) *ErrorResponse {
stmt := "DELETE FROM policy WHERE name = $1"
_, err := tx.Exec(stmt, policy.Name)
Expand All @@ -288,10 +297,46 @@ func (policy *Policy) deleteInDb(tx *sqlx.Tx) *ErrorResponse {
return nil
}

func (policy *Policy) overwriteInDb(tx *sqlx.Tx) *ErrorResponse {
errResponse := policy.deleteInDb(tx)
func (policy *Policy) updateInDb(tx *sqlx.Tx) *ErrorResponse {
// We do not allow updates to policy name (or id).

errResponse := policy.validate()
if errResponse != nil {
return errResponse
}

var policyID int
stmt := "UPDATE policy SET description = $1 WHERE name = $2 RETURNING id"
row := tx.QueryRowx(stmt, policy.Description, policy.Name)
err := row.Scan(&policyID)
switch {
case err == sql.ErrNoRows:
msg := fmt.Sprintf("failed to update policy: no policy found with id: %s", policy.Name)
return newErrorResponse(msg, 404, &err)
case err != nil:
msg := fmt.Sprintf("failed to update policy: update description failed: %s", err.Error())
return newErrorResponse(msg, 500, &err)
}

// First delete resources and roles that were previously attached to policy
stmt = "DELETE FROM policy_resource WHERE policy_id = $1"
_, err = tx.Exec(stmt, policyID)
if err != nil {
msg := fmt.Sprintf("database deletion from policy_resource failed: %s", err.Error())
return newErrorResponse(msg, 500, &err)
}
stmt = "DELETE FROM policy_role WHERE policy_id = $1"
_, err = tx.Exec(stmt, policyID)
if err != nil {
msg := fmt.Sprintf("database deletion from policy_role failed: %s", err.Error())
return newErrorResponse(msg, 500, &err)
}

// Now add the new resources and roles
errResponse = policy.addResourcesAndRoles(tx, policyID)
if errResponse != nil {
return errResponse
}
return policy.createInDb(tx)

return nil
}
6 changes: 3 additions & 3 deletions arborist/resource.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,10 +147,10 @@ func (resourceFromQuery *ResourceFromQuery) standardize() ResourceOut {
return resource
}

// FormatPathForDb takes a path from a resource in the database and transforms
// it to the front-end version of the resource path. Inverse of `formatDbPath`.
// FormatPathForDb takes a front-end version of a resource path and transforms
// it to its database version. Inverse of `formatDbPath`.
//
// formatDbPath("/a/b/c") == "a.b.c"
// FormatPathForDb("/a/b/c") == "a.b.c"
func FormatPathForDb(path string) string {
// -1 means replace everything
result := strings.TrimLeft(strings.Replace(UnderscoreEncode(path), "/", ".", -1), ".")
Expand Down
10 changes: 7 additions & 3 deletions arborist/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -591,8 +591,6 @@ func (server *Server) handlePolicyCreate(w http.ResponseWriter, r *http.Request,

func (server *Server) handlePolicyOverwrite(w http.ResponseWriter, r *http.Request, body []byte) {
policy := &Policy{}
// uncomment this after 3.0.0
// policy.Name = mux.Vars(r)["policyID"]
err := json.Unmarshal(body, policy)
if err != nil {
msg := fmt.Sprintf("could not parse policy from JSON: %s", err.Error())
Expand All @@ -601,7 +599,13 @@ func (server *Server) handlePolicyOverwrite(w http.ResponseWriter, r *http.Reque
_ = response.write(w, r)
return
}
errResponse := transactify(server.db, policy.overwriteInDb)
// Overwrite policy name from json with policy name from query arg.
// After 3.0.0, when PUT /policy is deprecated and only PUT /policy/{policyID} is allowed,
// can remove the !="" check. For now, if policy name not found in url, default to name in json.
if mux.Vars(r)["policyID"] != "" {
policy.Name = mux.Vars(r)["policyID"]
}
errResponse := transactify(server.db, policy.updateInDb)
if errResponse != nil {
errResponse.log.write(server.logger)
_ = errResponse.write(w, r)
Expand Down

0 comments on commit db16b60

Please sign in to comment.