From 443dc2879a1ccf6722214c1ad1c7dc50d7c0c8ac Mon Sep 17 00:00:00 2001 From: John McCann Date: Thu, 16 Sep 2021 18:31:34 -0700 Subject: [PATCH] feat(authMapping): account for expired policies --- arborist/auth.go | 4 ++-- arborist/server_test.go | 18 ++++++++++++++++++ 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/arborist/auth.go b/arborist/auth.go index beb474a..609a9fb 100644 --- a/arborist/auth.go +++ b/arborist/auth.go @@ -664,12 +664,12 @@ func authMapping(db *sqlx.DB, username string) (AuthMapping, *ErrorResponse) { ( SELECT usr_policy.policy_id FROM usr INNER JOIN usr_policy ON usr_policy.usr_id = usr.id - WHERE usr.name = $1 + WHERE usr.name = $1 AND (usr_policy.expires_at IS NULL OR NOW() < usr_policy.expires_at) UNION SELECT grp_policy.policy_id FROM usr INNER JOIN usr_grp ON usr_grp.usr_id = usr.id INNER JOIN grp_policy ON grp_policy.grp_id = usr_grp.grp_id - WHERE usr.name = $1 + WHERE usr.name = $1 AND (usr_grp.expires_at IS NULL OR NOW() < usr_grp.expires_at) UNION SELECT grp_policy.policy_id FROM grp INNER JOIN grp_policy ON grp_policy.grp_id = grp.id diff --git a/arborist/server_test.go b/arborist/server_test.go index 4b6bdc2..f2168d4 100644 --- a/arborist/server_test.go +++ b/arborist/server_test.go @@ -2804,6 +2804,24 @@ func TestServer(t *testing.T) { } }) + t.Run("GET_expiredPolicy", func(t *testing.T) { + expiredTimestamp := time.Now().Add(time.Duration(-1) * time.Minute).Format(time.RFC3339) + grantExpiringUserPolicy(t, username, policyName, expiredTimestamp) + w := httptest.NewRecorder() + url := fmt.Sprintf("/auth/mapping?username=%s", username) + req := newRequest("GET", url, nil) + handler.ServeHTTP(w, req) + + result := make(map[string][]arborist.Action) + err = json.Unmarshal(w.Body.Bytes(), &result) + if err != nil { + httpError(t, w, "couldn't read response from auth mapping") + } + msg := fmt.Sprintf("result contains resource %s corresponding to expired policy %s", resourcePath, policyName) + assert.NotContains(t, result, resourcePath, msg) + grantUserPolicy(t, username, policyName) + }) + t.Run("POST", func(t *testing.T) { w := httptest.NewRecorder() body := []byte(fmt.Sprintf(`{"username": "%s"}`, username))