Skip to content

Commit

Permalink
Support "client credentials" tokens (#42)
Browse files Browse the repository at this point in the history
  • Loading branch information
paulineribeyre committed Aug 12, 2022
1 parent d35cee0 commit 0f4974c
Show file tree
Hide file tree
Showing 8 changed files with 133 additions and 33 deletions.
7 changes: 6 additions & 1 deletion src/requestor/routes/manage.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,12 @@ async def create_request(
if not data.get("username"):
logger.debug("No username provided in body, using token username")
token_claims = await auth.get_token_claims()
token_username = token_claims["context"]["user"]["name"]
token_username = token_claims.get("context", {}).get("user", {}).get("name")
if not token_username:
raise HTTPException(
HTTP_400_BAD_REQUEST,
"Must provide a username in the request body or token",
)
logger.debug(f"Got username from token: {token_username}")
data["username"] = token_username

Expand Down
17 changes: 15 additions & 2 deletions src/requestor/routes/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from starlette.status import (
HTTP_200_OK,
HTTP_400_BAD_REQUEST,
HTTP_403_FORBIDDEN,
HTTP_404_NOT_FOUND,
)

Expand Down Expand Up @@ -105,6 +106,8 @@ async def list_requests(

# get the resources the current user has access to see
token_claims = await auth.get_token_claims()
# TODO update this endpoint to accept client tokens. We need to get the
# auth mapping for the client instead of the user
username = token_claims["context"]["user"]["name"]
authz_mapping = await api_request.app.arborist_client.auth_mapping(username)
authorized_resource_paths = [
Expand Down Expand Up @@ -171,7 +174,12 @@ async def list_user_requests(api_request: Request, auth=Depends(Auth)) -> dict:
# their own requests.
filter_dict, active = populate_filters_from_query_params(api_request.query_params)
token_claims = await auth.get_token_claims()
username = token_claims["context"]["user"]["name"]
username = token_claims.get("context", {}).get("user", {}).get("name")
if not username:
raise HTTPException(
HTTP_403_FORBIDDEN,
"This endpoint does not support tokens that are not linked to a user",
)
logger.debug(f"Getting requests for user '{username}' with active = '{active}'")
user_requests = await get_filtered_requests(
# if we only want active requests, filter out requests in a final status
Expand Down Expand Up @@ -237,7 +245,12 @@ async def check_user_resource_paths(
# no authz checks because we assume the current user can read
# their own requests.
token_claims = await auth.get_token_claims()
username = token_claims["context"]["user"]["name"]
username = token_claims.get("context", {}).get("user", {}).get("name")
if not username:
raise HTTPException(
HTTP_403_FORBIDDEN,
"This endpoint does not support tokens that are not linked to a user",
)
user_requests = await get_filtered_requests(username, draft=False, final=False)
positive_requests = [r for r in user_requests if not r.revoke]
existing_policies = await arborist.list_policies(
Expand Down
35 changes: 33 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,8 +170,39 @@ def list_roles_patcher():
role_patch.stop()


@pytest.fixture(autouse=True, scope="function")
def access_token_patcher(client, request):
@pytest.fixture(autouse=True, scope="function", params=["user_token", "client_token"])
def access_token_user_client_patcher(client, request):
"""
The `access_token` function will return first a token linked to a test
user, then a token linked to a test client.
"""

async def get_access_token(*args, **kwargs):
if request.param == "user_token":
return {"sub": "1", "context": {"user": {"name": "requestor_user"}}}
if request.param == "client_token":
return {"context": {}, "azp": "test-client-id"}

access_token_mock = MagicMock()
access_token_mock.return_value = get_access_token

access_token_patch = patch("requestor.auth.access_token", access_token_mock)
access_token_patch.start()

yield access_token_mock

access_token_patch.stop()


@pytest.fixture(scope="function")
def access_token_user_only_patcher(client, request):
"""
The `access_token` function will return a token linked to a test user.
This fixture should be used explicitely instead of the automatic
`access_token_user_client_patcher` fixture for endpoints that do not
support client tokens.
"""

async def get_access_token(*args, **kwargs):
return {"sub": "1", "context": {"user": {"name": "requestor_user"}}}

Expand Down
2 changes: 1 addition & 1 deletion tests/test_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ def test_backoff_retry(client):
assert mock_requests.post.call_count == config["DEFAULT_MAX_RETRIES"]


def test_create_request_failure_revert(client):
def test_create_request_failure_revert(client, access_token_user_only_patcher):
"""
If something goes wrong during an external call, access should not be
granted, the request should not be created and we should get a 500.
Expand Down
6 changes: 4 additions & 2 deletions tests/test_manage.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def test_create_request_with_unallowed_params(client, data):
assert data["err_msg"] in res.json()["detail"]


def test_create_request_without_username(client):
def test_create_request_without_username(client, access_token_user_only_patcher):
"""
When a username is not provided in the body, the request is created
using the username from the provided access token.
Expand Down Expand Up @@ -146,7 +146,9 @@ def test_create_duplicate_request(client):
assert res.status_code == 201, res.text


def test_create_request_without_access(client, mock_arborist_requests):
def test_create_request_without_access(
client, mock_arborist_requests, access_token_user_only_patcher
):
fake_jwt = "1.2.3"
mock_arborist_requests(authorized=False)

Expand Down
8 changes: 6 additions & 2 deletions tests/test_manage_resource_path.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ def test_create_request_with_resource_paths_and_role_ids(
},
],
)
def test_create_request_without_username(client, data):
def test_create_request_without_username(client, data, access_token_user_only_patcher):
"""
When a username is not provided in the body, the request is created
using the username from the provided access token.
Expand Down Expand Up @@ -441,7 +441,11 @@ def test_create_duplicate_request(client, data):
],
)
def test_create_request_without_access(
client, mock_arborist_requests, list_roles_patcher, data
client,
mock_arborist_requests,
list_roles_patcher,
data,
access_token_user_only_patcher,
):
fake_jwt = "1.2.3"
mock_arborist_requests(authorized=False)
Expand Down
71 changes: 55 additions & 16 deletions tests/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,9 @@
from requestor.config import config


def test_create_get_and_list_request(client):
def test_create_and_get_request(client):
fake_jwt = "1.2.3"

# list requests: empty
res = client.get("/request", headers={"Authorization": f"bearer {fake_jwt}"})
assert res.status_code == 200
assert res.json() == []

# create a request
data = {
"username": "requestor_user",
Expand Down Expand Up @@ -43,6 +38,42 @@ def test_create_get_and_list_request(client):
assert res.status_code == 200, res.text
assert res.json() == request_data


def test_create_and_list_request(client, access_token_user_only_patcher):
fake_jwt = "1.2.3"

# list requests: empty
res = client.get("/request", headers={"Authorization": f"bearer {fake_jwt}"})
assert res.status_code == 200
assert res.json() == []

# create a request
data = {
"username": "requestor_user",
"policy_id": "test-policy",
"resource_id": "uniqid",
"resource_display_name": "My Resource",
}
res = client.post(
"/request", json=data, headers={"Authorization": f"bearer {fake_jwt}"}
)
assert res.status_code == 201, res.text
request_data = res.json()
request_id = request_data.get("request_id")
assert request_id, "POST /request did not return a request_id"
assert request_data == {
"request_id": request_id,
"username": data["username"],
"policy_id": data["policy_id"],
"resource_id": data["resource_id"],
"resource_display_name": data["resource_display_name"],
"status": config["DEFAULT_INITIAL_STATUS"],
# just ensure revoke, created_time and updated_time are there:
"revoke": False,
"created_time": request_data["created_time"],
"updated_time": request_data["updated_time"],
}

# list requests
res = client.get("/request", headers={"Authorization": f"bearer {fake_jwt}"})
assert res.status_code == 200, res.text
Expand Down Expand Up @@ -92,7 +123,7 @@ def test_get_request_without_access(client, mock_arborist_requests):
assert not_found_err == unauthorized_err


def test_get_filtered_requests(client):
def test_get_filtered_requests(client, access_token_user_only_patcher):

fake_jwt = "1.2.3"
filtered_requests = []
Expand Down Expand Up @@ -173,7 +204,7 @@ def test_get_filtered_requests(client):
assert res.status_code == 400, res.text


def test_get_user_requests(client):
def test_get_user_requests(client, access_token_user_only_patcher):
fake_jwt = "1.2.3"

# create a request for the current user
Expand Down Expand Up @@ -205,7 +236,7 @@ def test_get_user_requests(client):
assert res.status_code == 401, res.text


def test_get_active_user_requests(client):
def test_get_active_user_requests(client, access_token_user_only_patcher):
fake_jwt = "1.2.3"

# create a request with a DRAFT status
Expand Down Expand Up @@ -260,7 +291,7 @@ def test_get_active_user_requests(client):
assert res.json() == [active_request1, active_request2]


def test_get_filtered_user_requests(client):
def test_get_filtered_user_requests(client, access_token_user_only_patcher):
fake_jwt = "1.2.3"
filtered_requests = []

Expand Down Expand Up @@ -337,7 +368,7 @@ def test_get_filtered_user_requests(client):
assert res.status_code == 400, res.text


def test_list_requests_with_access(client):
def test_list_requests_with_access(client, access_token_user_only_patcher):
fake_jwt = "1.2.3"

# create requests
Expand Down Expand Up @@ -403,7 +434,9 @@ def test_list_requests_with_access(client):
},
],
)
def test_check_user_resource_paths_prefixes(client, list_policies_patcher, test_data):
def test_check_user_resource_paths_prefixes(
client, list_policies_patcher, test_data, access_token_user_only_patcher
):
"""
Test if having requested access to the resource path in
test_data["resource_path"] means having requested access to
Expand Down Expand Up @@ -448,7 +481,9 @@ def test_check_user_resource_paths_prefixes(client, list_policies_patcher, test_
}
],
)
def test_check_user_resource_paths_multiple(client, list_policies_patcher, test_data):
def test_check_user_resource_paths_multiple(
client, list_policies_patcher, test_data, access_token_user_only_patcher
):
fake_jwt = "1.2.3"
expected_matches = {
"/a/b": True,
Expand Down Expand Up @@ -481,7 +516,7 @@ def test_check_user_resource_paths_multiple(client, list_policies_patcher, test_
assert res.json() == expected_matches


def test_check_user_resource_paths_username(client):
def test_check_user_resource_paths_username(client, access_token_user_only_patcher):
fake_jwt = "1.2.3"

resource_path_to_match = "/a/b"
Expand Down Expand Up @@ -534,7 +569,9 @@ def test_check_user_resource_paths_username(client):
},
],
)
def test_check_user_resource_paths_status(client, list_policies_patcher, test_data):
def test_check_user_resource_paths_status(
client, list_policies_patcher, test_data, access_token_user_only_patcher
):
fake_jwt = "1.2.3"

# create a request with the status to test
Expand Down Expand Up @@ -588,7 +625,9 @@ def test_check_user_resource_paths_status(client, list_policies_patcher, test_da
},
],
)
def test_check_permissions_mismatch(client, list_policies_patcher, test_data):
def test_check_permissions_mismatch(
client, list_policies_patcher, test_data, access_token_user_only_patcher
):
fake_jwt = "1.2.3"

# create a request with an active status
Expand Down
20 changes: 13 additions & 7 deletions tests/test_query_resource_path.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from requestor.config import config


def test_create_get_and_list_request(client):
def test_create_get_and_list_request(client, access_token_user_only_patcher):
fake_jwt = "1.2.3"

# list requests: empty
Expand Down Expand Up @@ -99,7 +99,7 @@ def test_get_request_without_access(client, mock_arborist_requests):
assert not_found_err == unauthorized_err


def test_get_user_requests(client):
def test_get_user_requests(client, access_token_user_only_patcher):
fake_jwt = "1.2.3"

# create a request for the current user
Expand Down Expand Up @@ -132,7 +132,7 @@ def test_get_user_requests(client):
assert res.status_code == 401, res.text


def test_list_requests_with_access(client):
def test_list_requests_with_access(client, access_token_user_only_patcher):
fake_jwt = "1.2.3"

# create requests
Expand Down Expand Up @@ -190,7 +190,9 @@ def test_list_requests_with_access(client):
},
],
)
def test_check_user_resource_paths_prefixes(client, list_policies_patcher, test_data):
def test_check_user_resource_paths_prefixes(
client, list_policies_patcher, test_data, access_token_user_only_patcher
):
"""
Test if having requested access to the resource path in
test_data["resource_path"] means having requested access to
Expand Down Expand Up @@ -227,7 +229,9 @@ def test_check_user_resource_paths_prefixes(client, list_policies_patcher, test_


@pytest.mark.parametrize("test_data", [{"resource_paths": ["/a/b", "/c"]}])
def test_check_user_resource_paths_multiple(client, list_policies_patcher, test_data):
def test_check_user_resource_paths_multiple(
client, list_policies_patcher, test_data, access_token_user_only_patcher
):
fake_jwt = "1.2.3"
existing_resource_paths = test_data["resource_paths"]
expected_matches = {
Expand Down Expand Up @@ -262,7 +266,7 @@ def test_check_user_resource_paths_multiple(client, list_policies_patcher, test_
assert res.json() == expected_matches


def test_check_user_resource_paths_username(client):
def test_check_user_resource_paths_username(client, access_token_user_only_patcher):
fake_jwt = "1.2.3"

resource_path_to_match = "/a/b"
Expand Down Expand Up @@ -312,7 +316,9 @@ def test_check_user_resource_paths_username(client):
},
],
)
def test_check_user_resource_paths_status(client, list_policies_patcher, test_data):
def test_check_user_resource_paths_status(
client, list_policies_patcher, test_data, access_token_user_only_patcher
):
fake_jwt = "1.2.3"
resource_path_to_match = test_data["resource_path"]

Expand Down

0 comments on commit 0f4974c

Please sign in to comment.