Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ test = [
"freezegun",
"mypy",
"nox",
"pyflakes",
"pyflakes == 3.2.0", # 3.3.0 causes some grief with a new warning.
"pylint",
# "pylint[spelling]", ## TODO
"pytest",
Expand Down
24 changes: 12 additions & 12 deletions src/planet_auth/auth_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def _get_typename_map(cls):
return cls._typename_map

@classmethod
def from_dict(cls, config_data: dict) -> AuthClientConfig:
def from_dict(cls, config_data: Dict) -> AuthClientConfig:
"""
Create a AuthClientConfig from a configuration dictionary.
Returns:
Expand Down Expand Up @@ -158,7 +158,7 @@ def from_file(file_path, storage_provider: Optional[ObjectStorageProvider] = Non

@classmethod
@abstractmethod
def meta(cls) -> dict:
def meta(cls) -> Dict:
"""
Return a dictionary of metadata.
The meta dictionary provides a place to store information that is
Expand Down Expand Up @@ -269,7 +269,7 @@ def login(
implementations should raise an exception for all login errors.
"""

def device_login_initiate(self, **kwargs) -> dict:
def device_login_initiate(self, **kwargs) -> Dict:
"""
Initiate the process to login a device with limited UI capabilities.
The returned dictionary should contain information for the application
Expand All @@ -286,7 +286,7 @@ def device_login_initiate(self, **kwargs) -> dict:
"""
raise AuthClientException(message="Device login is not supported for the current authentication mechanism")

def device_login_complete(self, initiated_login_data: dict) -> Credential:
def device_login_complete(self, initiated_login_data: Dict) -> Credential:
"""
Complete a login process that was initiated by a call to `device_login_initiate()`.

Expand Down Expand Up @@ -319,7 +319,7 @@ def refresh(self, refresh_token: str, requested_scopes: List[str]) -> Credential
"""
raise AuthClientException(message="Refresh not implemented for the current authentication mechanism")

def validate_access_token_remote(self, access_token: str) -> dict:
def validate_access_token_remote(self, access_token: str) -> Dict:
"""
Validate an access token with the authorization server.
Parameters:
Expand All @@ -337,7 +337,7 @@ def validate_access_token_remote(self, access_token: str) -> dict:

def validate_access_token_local(
self, access_token: str, required_audience: str = None, scopes_anyof: list = None
) -> dict:
) -> Dict:
"""
Validate an access token locally. While the validation is local,
the authorization server may still may contacted to obtain signing
Expand Down Expand Up @@ -382,7 +382,7 @@ def validate_access_token_local(
message="Access token validation is not implemented for the current authentication mechanism"
)

def validate_id_token_remote(self, id_token: str) -> dict:
def validate_id_token_remote(self, id_token: str) -> Dict:
"""
Validate an ID token with the authorization server.
Parameters:
Expand All @@ -394,7 +394,7 @@ def validate_id_token_remote(self, id_token: str) -> dict:
message="ID token validation is not implemented for the current authentication mechanism"
)

def validate_id_token_local(self, id_token: str) -> dict:
def validate_id_token_local(self, id_token: str) -> Dict:
"""
Validate an ID token locally. The authorization server may still be
called to obtain signing keys for validation. Signing keys will be
Expand All @@ -408,7 +408,7 @@ def validate_id_token_local(self, id_token: str) -> dict:
message="ID token validation is not implemented for the current authentication mechanism"
)

def validate_refresh_token_remote(self, refresh_token: str) -> dict:
def validate_refresh_token_remote(self, refresh_token: str) -> Dict:
"""
Validate a refresh token with the authorization server.
Parameters:
Expand Down Expand Up @@ -440,7 +440,7 @@ def revoke_refresh_token(self, refresh_token: str):
message="Refresh token revocation is not implemented for the current authentication mechanism"
)

def userinfo_from_access_token(self, access_token: str) -> dict:
def userinfo_from_access_token(self, access_token: str) -> Dict:
"""
Look up user information from the auth server using the access token.
Parameters:
Expand All @@ -450,7 +450,7 @@ def userinfo_from_access_token(self, access_token: str) -> dict:
message="User information lookup is not implemented for the current authentication mechanism"
)

def oidc_discovery(self) -> dict:
def oidc_discovery(self) -> Dict:
"""
Query the authorization server's OIDC discovery endpoint for server information.
Returns:
Expand All @@ -460,7 +460,7 @@ def oidc_discovery(self) -> dict:
message="OIDC discovery is not implemented for the current authentication mechanism."
)

# def oauth_discovery(self) -> dict:
# def oauth_discovery(self) -> Dict:
# """
# Query the authorization server's OAuth2 discovery endpoint for server information.
# Returns:
Expand Down
40 changes: 30 additions & 10 deletions src/planet_auth/oidc/api_clients/api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,25 @@
# limitations under the License.

from abc import ABC
from requests import Session
from requests import Session, Response
from requests.adapters import HTTPAdapter
from requests.auth import AuthBase
from typing import Callable, Dict, Optional, Tuple
from urllib3.util.retry import Retry

from planet_auth.auth_client import AuthClientException
from planet_auth.constants import X_PLANET_APP_HEADER, X_PLANET_APP
from planet_auth.util import parse_content_type

EnricherPayloadType = Dict
# EnricherAudType = str
EnricherReturnType = Tuple[Dict, Optional[AuthBase]]
EnricherFuncType = Callable[[EnricherPayloadType, str], EnricherReturnType]

_RequestAuthType = AuthBase
_RequestParamsType = Dict # Requests allows a lot more, but constrain our use.
_RequestResponseType = Response


class OidcApiClient(ABC):
"""
Expand All @@ -34,6 +44,8 @@ class OidcApiClient(ABC):
# Generally, this will be some combination of the client ID and secret
# and may be a header or payload adjustment. But sometimes, we just
# need to use an Authorization header.
# TODO: dog-food - use our own RequestAuthenticator like we do for the
# static API key auth client
class TokenBearerAuth(AuthBase):
def __init__(self, token):
self._token = token
Expand All @@ -42,7 +54,7 @@ def __call__(self, r):
r.headers["Authorization"] = "Bearer " + self._token
return r

def __init__(self, endpoint_uri):
def __init__(self, endpoint_uri: str):
self._endpoint_uri = endpoint_uri

retry_strategy = Retry(total=3, backoff_factor=1, status_forcelist=[429], allowed_methods=["POST", "GET"])
Expand All @@ -51,7 +63,7 @@ def __init__(self, endpoint_uri):
self._session.mount("https://", adapter)
# self._session.mount("http://", adapter)

def __check_http_error(self, response):
def __check_http_error(self, response: _RequestResponseType) -> None:
if not response.ok:
raise OidcApiClientException(
message="HTTP error from OIDC endpoint at {}: {}: {}".format(
Expand All @@ -60,7 +72,7 @@ def __check_http_error(self, response):
raw_response=response,
)

def __check_oidc_payload_json_error(self, response):
def __check_oidc_payload_json_error(self, response: _RequestResponseType) -> None:
if response.content:
ct = parse_content_type(response.headers.get("content-type"))
if not ct["content-type"] == "application/json":
Expand Down Expand Up @@ -89,7 +101,7 @@ def __check_oidc_payload_json_error(self, response):
)

@staticmethod
def __checked_json_response(response):
def __checked_json_response(response: _RequestResponseType) -> Dict:
json_response = None
if response.content:
ct = parse_content_type(response.headers.get("content-type"))
Expand All @@ -106,13 +118,15 @@ def __checked_json_response(response):
)
return json_response

def __check_response(self, response):
def __check_response(self, response: _RequestResponseType) -> None:
# Check for the json error first so we throw a more specific parsed
# error if we understand it, regardless of HTTP status code.
self.__check_oidc_payload_json_error(response)
self.__check_http_error(response)

def _checked_get(self, params, request_auth):
def _checked_get(
self, params: Optional[_RequestParamsType], request_auth: Optional[_RequestAuthType]
) -> _RequestResponseType:
response = self._session.get(
self._endpoint_uri,
params=params,
Expand All @@ -122,7 +136,9 @@ def _checked_get(self, params, request_auth):
self.__check_response(response)
return response

def _checked_post(self, params, request_auth):
def _checked_post(
self, params: Optional[_RequestParamsType], request_auth: Optional[_RequestAuthType]
) -> _RequestResponseType:
response = self._session.post(
self._endpoint_uri,
# Note: is the data/params crossing confusing? This was born out
Expand All @@ -139,10 +155,14 @@ def _checked_post(self, params, request_auth):
self.__check_response(response)
return response

def _checked_post_json_response(self, params, request_auth):
def _checked_post_json_response(
self, params: Optional[_RequestParamsType], request_auth: Optional[_RequestAuthType]
) -> Dict:
return self.__checked_json_response(self._checked_post(params, request_auth))

def _checked_get_json_response(self, params, request_auth):
def _checked_get_json_response(
self, params: Optional[_RequestParamsType], request_auth: Optional[_RequestAuthType]
) -> Dict:
return self.__checked_json_response(self._checked_get(params, request_auth))


Expand Down
14 changes: 8 additions & 6 deletions src/planet_auth/oidc/api_clients/authorization_api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from http import HTTPStatus
from urllib.parse import urlparse, parse_qs, urlencode
from typing import List
from typing import Dict, List, Optional
from webbrowser import open_new

import planet_auth.logging.auth_logger
Expand Down Expand Up @@ -125,7 +125,9 @@ class AuthorizationApiClient:
interactive authentication can be performed.
"""

def __init__(self, authorization_uri=None, authorization_callback_acknowledgement_response_body=None):
def __init__(
self, authorization_uri: str, authorization_callback_acknowledgement_response_body: Optional[str] = None
):
"""
Create a new Authorization API Client.
"""
Expand All @@ -146,8 +148,8 @@ def prep_pkce_auth_payload(
requested_scopes: List[str],
requested_audiences: List[str],
pkce_code_challenge: str,
extra: dict,
) -> dict:
extra: Dict,
) -> Dict:
"""
Prepare the payload needed to make an authorization request to an
OAuth authorization endpoint. This will usually be used to construct
Expand Down Expand Up @@ -218,7 +220,7 @@ def authcode_from_pkce_auth_request_with_browser_and_callback_listener(
requested_scopes: List[str],
requested_audiences: List[str],
pkce_code_challenge: str,
extra: dict,
extra: Dict,
) -> str:
"""
Request an authorization code by launching a web browser directed to the
Expand Down Expand Up @@ -309,7 +311,7 @@ def authcode_from_pkce_auth_request_with_tty_input(
requested_scopes: List[str],
requested_audiences: List[str],
pkce_code_challenge: str,
extra: dict,
extra: Dict,
) -> str:
"""
Request an authorization code by prompting the user to visit a
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from planet_auth.oidc.api_clients.api_client import OidcApiClient, OidcApiClientException
from typing import Dict, List, Optional

from planet_auth.oidc.api_clients.api_client import (
OidcApiClient,
OidcApiClientException,
EnricherFuncType,
_RequestParamsType,
_RequestAuthType,
)


class DeviceAuthorizationApiException(OidcApiClientException):
Expand All @@ -28,11 +36,16 @@ class DeviceAuthorizationApiClient(OidcApiClient):
All invalid responses or error responses will result in an exception.
"""

def __init__(self, device_authorization_uri=None):
def __init__(self, device_authorization_uri: str):
super().__init__(endpoint_uri=device_authorization_uri)

@staticmethod
def _prep_device_code_request_payload(client_id, requested_scopes, requested_audiences, extra):
def _prep_device_code_request_payload(
client_id,
requested_scopes: Optional[List[str]],
requested_audiences: Optional[List[str]],
extra: Optional[Dict],
) -> Dict:
if extra is None:
extra = {}
# "None" is pythonic, and does not mean anything to OAuth APIs.
Expand All @@ -50,7 +63,7 @@ def _prep_device_code_request_payload(client_id, requested_scopes, requested_aud
return data

@staticmethod
def _check_device_auth_response(json_response):
def _check_device_auth_response(json_response: Dict) -> Dict:
# Protocol endpoint specific response checks
if not json_response.get("device_code"):
raise DeviceAuthorizationApiException(
Expand All @@ -71,11 +84,20 @@ def _check_device_auth_response(json_response):
# verification_uri_complete and interval are optional under the spec, so we don't force them to be present.
return json_response

def _checked_request_device_code_call(self, request_params, request_auth):
def _checked_request_device_code_call(
self, request_params: _RequestParamsType, request_auth: Optional[_RequestAuthType]
) -> Dict:
json_response = self._checked_post_json_response(request_params, request_auth)
return self._check_device_auth_response(json_response)

def request_device_code(self, client_id: str, requested_scopes, requested_audiences, auth_enricher, extra):
def request_device_code(
self,
client_id: str,
requested_scopes: Optional[List[str]],
requested_audiences: Optional[List[str]],
auth_enricher: Optional[EnricherFuncType],
extra,
) -> Dict:
request_params = self._prep_device_code_request_payload(
client_id=client_id,
requested_scopes=requested_scopes,
Expand Down
Loading