Skip to content

Commit

Permalink
Merge 09c4c59 into c58e30b
Browse files Browse the repository at this point in the history
  • Loading branch information
vpsx committed Sep 17, 2019
2 parents c58e30b + 09c4c59 commit 783cfa0
Show file tree
Hide file tree
Showing 3 changed files with 131 additions and 15 deletions.
130 changes: 121 additions & 9 deletions arborist/policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,7 @@ func (policy *Policy) UnmarshalJSON(data []byte) error {
return err
}

// delete fields which should be ignored in user input
delete(fields, "ID")
// uncomment this after 3.0.0 and lowercase the ID field in optionalFields too
// delete(fields, "id")

optionalFields := map[string]struct{}{
"ID": struct{}{},
"description": struct{}{},
}
err = validateJSON("policy", policy, fields, optionalFields)
Expand Down Expand Up @@ -169,6 +163,9 @@ func (policy *Policy) roles(tx *sqlx.Tx) ([]RoleFromQuery, error) {
// looking at the database. This includes that the policy must contain at least
// one resource and at least one role.
func (policy *Policy) validate() *ErrorResponse {
if len(policy.Name) == 0 {
return newErrorResponse("policy name cannot be empty string", 400, nil)
}
// Resources and roles must be non-empty
if len(policy.ResourcePaths) == 0 {
return newErrorResponse("no resource paths specified", 400, nil)
Expand Down Expand Up @@ -203,6 +200,7 @@ func (policy *Policy) createInDb(tx *sqlx.Tx) *ErrorResponse {
// `resources` is a list of looked-up resources which appear in the input policy
resources, err := policy.resources(tx)
if err != nil {
_ = tx.Rollback()
msg := fmt.Sprintf("database call for resources failed: %s", err.Error())
return newErrorResponse(msg, 500, &err)
}
Expand Down Expand Up @@ -240,6 +238,7 @@ func (policy *Policy) createInDb(tx *sqlx.Tx) *ErrorResponse {

roles, err := policy.roles(tx)
if err != nil {
_ = tx.Rollback()
msg := fmt.Sprintf("database call for roles failed: %s", err.Error())
return newErrorResponse(msg, 500, &err)
}
Expand Down Expand Up @@ -288,10 +287,123 @@ func (policy *Policy) deleteInDb(tx *sqlx.Tx) *ErrorResponse {
return nil
}

func (policy *Policy) overwriteInDb(tx *sqlx.Tx) *ErrorResponse {
errResponse := policy.deleteInDb(tx)
func (policy *Policy) updateInDb(tx *sqlx.Tx) *ErrorResponse {
// We do not allow updates to policy name (or id).

errResponse := policy.validate()
if errResponse != nil {
return errResponse
}
return policy.createInDb(tx)

var policyID int
stmt := "SELECT id FROM policy WHERE name = $1"
err := tx.Get(&policyID, stmt, policy.Name)
if err != nil {
_ = tx.Rollback()
msg := fmt.Sprintf("failed to update policy: no policy found with id: %s", policy.Name)
return newErrorResponse(msg, 404, &err)
}

stmt = "UPDATE policy SET description = $1 WHERE id = $2"
_, err = tx.Exec(stmt, policy.Description, policyID)
if err != nil {
_ = tx.Rollback()
msg := fmt.Sprintf("failed to update policy: update description failed: %s", err.Error())
return newErrorResponse(msg, 500, &err)
}

// First delete resources and roles that were previously attached to policy
stmt = "DELETE FROM policy_resource WHERE policy_id = $1"
_, err = tx.Exec(stmt, policyID)
if err != nil {
_ = tx.Rollback()
msg := fmt.Sprintf("database deletion from policy_resource failed: %s", err.Error())
return newErrorResponse(msg, 500, &err)
}
stmt = "DELETE FROM policy_role WHERE policy_id = $1"
_, err = tx.Exec(stmt, policyID)
if err != nil {
_ = tx.Rollback()
msg := fmt.Sprintf("database deletion from policy_role failed: %s", err.Error())
return newErrorResponse(msg, 500, &err)
}


// `resources` is a list of looked-up resources which appear in the input policy
resources, err := policy.resources(tx)
if err != nil {
_ = tx.Rollback()
msg := fmt.Sprintf("database call for resources failed: %s", err.Error())
return newErrorResponse(msg, 500, &err)
}
// make sure all resources for new policy exist in DB
resourceSet := make(map[string]struct{})
for _, resource := range resources {
path := formatDbPath(resource.Path)
resourceSet[path] = struct{}{}
}
missingResources := []string{}
for _, path := range policy.ResourcePaths {
if _, exists := resourceSet[path]; !exists {
missingResources = append(missingResources, path)
}
}
if len(missingResources) > 0 {
_ = tx.Rollback()
missingString := strings.Join(missingResources, ", ")
msg := fmt.Sprintf("failed to create policy: resources do not exist: %s", missingString)
return newErrorResponse(msg, 400, nil)
}
// try to insert relationships from this policy to all resources
stmt = multiInsertStmt("policy_resource(policy_id, resource_id)", len(resources))
policyResourceRows := []interface{}{}
for _, resource := range resources {
policyResourceRows = append(policyResourceRows, policyID)
policyResourceRows = append(policyResourceRows, resource.ID)
}
_, err = tx.Exec(stmt, policyResourceRows...)
if err != nil {
_ = tx.Rollback()
msg := fmt.Sprintf("failed to insert policy while linking resources: %s", err.Error())
return newErrorResponse(msg, 500, &err)
}

roles, err := policy.roles(tx)
if err != nil {
_ = tx.Rollback()
msg := fmt.Sprintf("database call for roles failed: %s", err.Error())
return newErrorResponse(msg, 500, &err)
}
// make sure all roles for new policy exist in DB
roleSet := make(map[string]struct{})
for _, role := range roles {
roleSet[role.Name] = struct{}{}
}
missingRoles := []string{}
for _, role := range policy.RoleIDs {
if _, exists := roleSet[role]; !exists {
missingRoles = append(missingRoles, role)
}
}
if len(missingRoles) > 0 {
_ = tx.Rollback()
missingString := strings.Join(missingRoles, ", ")
msg := fmt.Sprintf("failed to create policy: roles do not exist: %s", missingString)
return newErrorResponse(msg, 400, nil)
}
// try to insert relationships from this policy to all roles
stmt = multiInsertStmt("policy_role(policy_id, role_id)", len(roles))
policyRoleRows := []interface{}{}
for _, role := range roles {
policyRoleRows = append(policyRoleRows, policyID)
policyRoleRows = append(policyRoleRows, role.ID)
}
_, err = tx.Exec(stmt, policyRoleRows...)
if err != nil {
_ = tx.Rollback()
msg := fmt.Sprintf("failed to insert policy while linking roles: %s", err.Error())
return newErrorResponse(msg, 500, &err)
}

return nil
}
6 changes: 3 additions & 3 deletions arborist/resource.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,10 +147,10 @@ func (resourceFromQuery *ResourceFromQuery) standardize() ResourceOut {
return resource
}

// FormatPathForDb takes a path from a resource in the database and transforms
// it to the front-end version of the resource path. Inverse of `formatDbPath`.
// FormatPathForDb takes a front-end version of a resource path and transforms
// it to its database version. Inverse of `formatDbPath`.
//
// formatDbPath("/a/b/c") == "a.b.c"
// FormatPathForDb("/a/b/c") == "a.b.c"
func FormatPathForDb(path string) string {
// -1 means replace everything
result := strings.TrimLeft(strings.Replace(UnderscoreEncode(path), "/", ".", -1), ".")
Expand Down
10 changes: 7 additions & 3 deletions arborist/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -591,8 +591,6 @@ func (server *Server) handlePolicyCreate(w http.ResponseWriter, r *http.Request,

func (server *Server) handlePolicyOverwrite(w http.ResponseWriter, r *http.Request, body []byte) {
policy := &Policy{}
// uncomment this after 3.0.0
// policy.Name = mux.Vars(r)["policyID"]
err := json.Unmarshal(body, policy)
if err != nil {
msg := fmt.Sprintf("could not parse policy from JSON: %s", err.Error())
Expand All @@ -601,7 +599,13 @@ func (server *Server) handlePolicyOverwrite(w http.ResponseWriter, r *http.Reque
_ = response.write(w, r)
return
}
errResponse := transactify(server.db, policy.overwriteInDb)
// Overwrite policy name from json with policy name from query arg.
// After 3.0.0, when PUT /policy is deprecated and only PUT /policy/{policyID} is allowed,
// can remove the !="" check. For now, if policy name not found in url, default to name in json.
if mux.Vars(r)["policyID"] != "" {
policy.Name = mux.Vars(r)["policyID"]
}
errResponse := transactify(server.db, policy.updateInDb)
if errResponse != nil {
errResponse.log.write(server.logger)
_ = errResponse.write(w, r)
Expand Down

0 comments on commit 783cfa0

Please sign in to comment.