From 6d1ecf0698f5702bbdc7bfa51fea84ffab9eb34d Mon Sep 17 00:00:00 2001 From: rdimaio Date: Tue, 13 Feb 2024 11:41:27 +0100 Subject: [PATCH] Refactor to use TokenDict dataclass and add type hints; #6454 --- lib/rucio/common/types.py | 8 ++++++ lib/rucio/common/utils.py | 2 +- lib/rucio/core/account.py | 36 +++++++++++------------- lib/rucio/core/account_limit.py | 10 +++---- lib/rucio/core/authentication.py | 42 +++++++++++++--------------- lib/rucio/core/oidc.py | 48 +++++++++++++++----------------- 6 files changed, 72 insertions(+), 74 deletions(-) diff --git a/lib/rucio/common/types.py b/lib/rucio/common/types.py index c3b0628dd8a..9baa5306c16 100644 --- a/lib/rucio/common/types.py +++ b/lib/rucio/common/types.py @@ -14,6 +14,8 @@ # limitations under the License. from typing import Any, Callable, Optional, TypedDict, Union +from datetime import datetime +from dataclasses import dataclass class InternalType(object): @@ -159,3 +161,9 @@ class RSESettingsDict(TypedDict): deterministic: bool domain: list[str] protocols: list[RSEProtocolDict] + + +@dataclass +class TokenDict: + token: str + expires_at: datetime diff --git a/lib/rucio/common/utils.py b/lib/rucio/common/utils.py index 34e7b8b0130..8d015be88ba 100644 --- a/lib/rucio/common/utils.py +++ b/lib/rucio/common/utils.py @@ -570,7 +570,7 @@ def str_to_date(string): return datetime.datetime.strptime(string, DATE_FORMAT) if string else None -def val_to_space_sep_str(vallist): +def val_to_space_sep_str(vallist: list) -> str: """ Converts a list of values into a string of space separated values :param vallist: the list of values to to convert into string diff --git a/lib/rucio/core/account.py b/lib/rucio/core/account.py index 2c05ff6dc8f..f05b76f3737 100644 --- a/lib/rucio/core/account.py +++ b/lib/rucio/core/account.py @@ -17,8 +17,7 @@ from enum import Enum from re import match from traceback import format_exc -from typing import TYPE_CHECKING, Any -from collections.abc import Generator +from typing import TYPE_CHECKING, Any, Iterator, Optional import uuid from sqlalchemy import select, and_ @@ -84,7 +83,7 @@ def account_exists(account: InternalAccount, *, session: "Session") -> bool: @read_session -def get_account(account: InternalAccount, *, session: "Session") -> dict: +def get_account(account: InternalAccount, *, session: "Session") -> models.Account: """ Returns an account for the given account name. :param account: the name of the account. @@ -118,12 +117,11 @@ def del_account(account: InternalAccount, *, session: "Session"): models.Account.status == AccountStatus.ACTIVE ) try: - account = session.execute(query).scalar_one() + account_result = session.execute(query).scalar_one() + account_result.update({'status': AccountStatus.DELETED, 'deleted_at': datetime.utcnow()}) except exc.NoResultFound: raise exception.AccountNotFound('Account with ID \'%s\' cannot be found' % account) - account.update({'status': AccountStatus.DELETED, 'deleted_at': datetime.utcnow()}) - @transactional_session def update_account(account: InternalAccount, key: str, value: Any, *, session: "Session"): @@ -140,22 +138,22 @@ def update_account(account: InternalAccount, key: str, value: Any, *, session: " models.Account.account == account ) try: - account = session.execute(query).scalar_one() + account_result = session.execute(query).scalar_one() + if key == 'status': + if isinstance(value, str): + value = AccountStatus[value] + if value == AccountStatus.SUSPENDED: + account_result.update({'status': value, 'suspended_at': datetime.utcnow()}) + elif value == AccountStatus.ACTIVE: + account_result.update({'status': value, 'suspended_at': None}) + else: + account_result.update({key: value}) except exc.NoResultFound: raise exception.AccountNotFound('Account with ID \'%s\' cannot be found' % account) - if key == 'status': - if isinstance(value, str): - value = AccountStatus[value] - if value == AccountStatus.SUSPENDED: - account.update({'status': value, 'suspended_at': datetime.utcnow()}) - elif value == AccountStatus.ACTIVE: - account.update({'status': value, 'suspended_at': None}) - else: - account.update({key: value}) @stream_session -def list_accounts(filter_: dict = None, *, session: "Session") -> Generator[dict]: +def list_accounts(filter_: Optional[dict] = None, *, session: "Session") -> Iterator[dict]: """ Returns a list of all account names. :param filter_: Dictionary of attributes by which the input data should be filtered @@ -403,14 +401,14 @@ def get_all_rse_usages_per_account(account: InternalAccount, *, session: "Sessio @read_session -def get_usage_history(rse_id: uuid.UUID, account: InternalAccount, *, session: "Session") -> dict: +def get_usage_history(rse_id: uuid.UUID, account: InternalAccount, *, session: "Session") -> list[dict]: """ Returns historical values of the specified counter, or raises CounterNotFound if the counter does not exist. :param rse_id: The id of the RSE. :param account: The account name. :param session: The database session in use. - :returns: A dictionary {'bytes', 'files', 'updated_at'} + :returns: A list of dictionaries {'bytes', 'files', 'updated_at'} """ query = select( models.AccountUsageHistory.bytes, diff --git a/lib/rucio/core/account_limit.py b/lib/rucio/core/account_limit.py index b1e4be592fa..1d9b9aabdf2 100644 --- a/lib/rucio/core/account_limit.py +++ b/lib/rucio/core/account_limit.py @@ -92,7 +92,7 @@ def get_rse_account_usage(rse_id: uuid.UUID, *, session: "Session") -> list[dict @read_session -def get_global_account_limits(account: InternalAccount = None, *, session: "Session") -> dict: +def get_global_account_limits(account: Optional[InternalAccount] = None, *, session: "Session") -> dict: """ Returns the global account limits for the account. @@ -123,7 +123,7 @@ def get_global_account_limits(account: InternalAccount = None, *, session: "Sess @read_session -def get_global_account_limit(account: InternalAccount, rse_expression: str, *, session: "Session") -> int: +def get_global_account_limit(account: InternalAccount, rse_expression: str, *, session: "Session") -> Optional[int | float]: """ Returns the global account limit for the account on the rse expression. @@ -164,7 +164,7 @@ def get_local_account_limit(account: InternalAccount, rse_id: uuid.UUID, *, sess @read_session -def get_local_account_limits(account: InternalAccount, rse_ids: list[uuid.UUID] = None, *, session: "Session") -> dict: +def get_local_account_limits(account: InternalAccount, rse_ids: Optional[list[uuid.UUID]] = None, *, session: "Session") -> dict: """ Returns the account limits for the account on the list of rses. @@ -270,7 +270,7 @@ def delete_global_account_limit(account: InternalAccount, rse_expression: str, * @transactional_session -def get_local_account_usage(account: InternalAccount, rse_id: uuid.UUID = None, *, session: "Session") -> list[dict]: +def get_local_account_usage(account: InternalAccount, rse_id: Optional[uuid.UUID] = None, *, session: "Session") -> list[dict]: """ Read the account usage and connect it with (if available) the account limits of the account. @@ -313,7 +313,7 @@ def get_local_account_usage(account: InternalAccount, rse_id: uuid.UUID = None, @transactional_session -def get_global_account_usage(account: InternalAccount, rse_expression: str = None, *, session: "Session") -> list[dict]: +def get_global_account_usage(account: InternalAccount, rse_expression: Optional[str] = None, *, session: "Session") -> list[dict]: """ Read the account usage and connect it with the global account limits of the account. diff --git a/lib/rucio/core/authentication.py b/lib/rucio/core/authentication.py index cd5775ea390..44151c7ad38 100644 --- a/lib/rucio/core/authentication.py +++ b/lib/rucio/core/authentication.py @@ -20,7 +20,7 @@ import sys import traceback from base64 import b64decode -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional import paramiko from dogpile.cache import make_region @@ -30,6 +30,7 @@ from rucio.common.cache import make_region_memcached from rucio.common.config import config_get_bool from rucio.common.exception import CannotAuthenticate, RucioException +from rucio.common.types import InternalAccount, TokenDict from rucio.common.utils import chunks, generate_uuid, date_to_str from rucio.core.account import account_exists from rucio.core.oidc import validate_jwt @@ -90,7 +91,7 @@ def generate_key(token, *, session: "Session"): @transactional_session -def get_auth_token_user_pass(account, username, password, appid, ip=None, *, session: "Session"): +def get_auth_token_user_pass(account: InternalAccount, username: str, password: str, appid: str, ip: Optional[str] = None, *, session: "Session") -> Optional[TokenDict]: """ Authenticate a Rucio account temporarily via username and password. @@ -146,11 +147,11 @@ def get_auth_token_user_pass(account, username, password, appid, ip=None, *, ses new_token = models.Token(account=db_account, identity=username, token=token, ip=ip) new_token.save(session=session) - return token_dictionary(new_token) + return TokenDict(new_token.token, new_token.expired_at) @transactional_session -def get_auth_token_x509(account, dn, appid, ip=None, *, session: "Session"): +def get_auth_token_x509(account: InternalAccount, dn: str, appid: str, ip: Optional[str] = None, *, session: "Session") -> Optional[TokenDict]: """ Authenticate a Rucio account temporarily via an x509 certificate. @@ -178,11 +179,11 @@ def get_auth_token_x509(account, dn, appid, ip=None, *, session: "Session"): new_token = models.Token(account=account, identity=dn, token=token, ip=ip) new_token.save(session=session) - return token_dictionary(new_token) + return TokenDict(new_token.token, new_token.expired_at) @transactional_session -def get_auth_token_gss(account, gsstoken, appid, ip=None, *, session: "Session"): +def get_auth_token_gss(account: InternalAccount, gsstoken: str, appid: str, ip: Optional[str] = None, *, session: "Session") -> Optional[TokenDict]: """ Authenticate a Rucio account temporarily via a GSS token. @@ -210,11 +211,11 @@ def get_auth_token_gss(account, gsstoken, appid, ip=None, *, session: "Session") new_token = models.Token(account=account, token=token, ip=ip) new_token.save(session=session) - return token_dictionary(new_token) + return TokenDict(new_token.token, new_token.expired_at) @transactional_session -def get_auth_token_ssh(account, signature, appid, ip=None, *, session: "Session"): +def get_auth_token_ssh(account: InternalAccount, signature: str | bytes, appid: str, ip: Optional[str] = None, *, session: "Session") -> Optional[TokenDict]: """ Authenticate a Rucio account temporarily via SSH key exchange. @@ -228,7 +229,7 @@ def get_auth_token_ssh(account, signature, appid, ip=None, *, session: "Session" :returns: A dict with token and expires_at entries. """ - if not isinstance(signature, bytes): + if isinstance(signature, str): signature = signature.encode() # Make sure the account exists @@ -284,18 +285,17 @@ def get_auth_token_ssh(account, signature, appid, ip=None, *, session: "Session" new_token = models.Token(account=account, token=token, ip=ip) new_token.save(session=session) - return token_dictionary(new_token) + return TokenDict(new_token.token, new_token.expired_at) @transactional_session -def get_ssh_challenge_token(account, appid, ip=None, *, session: "Session"): +def get_ssh_challenge_token(account: InternalAccount, ip: Optional[str] = None, *, session: "Session") -> Optional[TokenDict]: """ Prepare a challenge token for subsequent SSH public key authentication. The challenge lifetime is fixed to 10 seconds. :param account: Account identifier as a string. - :param appid: The application identifier as a string. :param ip: IP address of the client as a string. :returns: A dict with token and expires_at entries. @@ -320,11 +320,11 @@ def get_ssh_challenge_token(account, appid, ip=None, *, session: "Session"): expired_at=expiration) new_challenge_token.save(session=session) - return token_dictionary(new_challenge_token) + return TokenDict(new_challenge_token.token, new_challenge_token.expired_at) @transactional_session -def get_auth_token_saml(account, saml_nameid, appid, ip=None, *, session: "Session"): +def get_auth_token_saml(account: InternalAccount, saml_nameid: str, appid: str, ip: Optional[str] = None, *, session: "Session") -> Optional[TokenDict]: """ Authenticate a Rucio account temporarily via SAML. @@ -351,11 +351,11 @@ def get_auth_token_saml(account, saml_nameid, appid, ip=None, *, session: "Sessi new_token = models.Token(account=account, identity=saml_nameid, token=token, ip=ip) new_token.save(session=session) - return token_dictionary(new_token) + return TokenDict(new_token.token, new_token.expired_at) @transactional_session -def redirect_auth_oidc(auth_code, fetchtoken=False, *, session: "Session"): +def redirect_auth_oidc(auth_code: str, fetchtoken: bool = False, *, session: "Session") -> Optional[str]: """ Finds the Authentication URL in the Rucio DB oauth_requests table and redirects user's browser to this URL. @@ -396,7 +396,7 @@ def redirect_auth_oidc(auth_code, fetchtoken=False, *, session: "Session"): @transactional_session -def delete_expired_tokens(total_workers, worker_number, limit=1000, *, session: "Session"): +def delete_expired_tokens(total_workers: int, worker_number: int, limit: int = 1000, *, session: "Session") -> int: """ Delete expired tokens. @@ -456,7 +456,7 @@ def delete_expired_tokens(total_workers, worker_number, limit=1000, *, session: @read_session -def query_token(token, *, session: "Session"): +def query_token(token: str, *, session: "Session") -> Optional[dict]: """ Validate an authentication token using the database. This method will only be called if no entry could be found in the according cache. @@ -530,12 +530,8 @@ def validate_auth_token(token: str, *, session: "Session") -> "dict[str, Any]": return value -def token_dictionary(token: models.Token): - return {'token': token.token, 'expires_at': token.expired_at} - - @transactional_session -def __delete_expired_tokens_account(account, *, session: "Session"): +def __delete_expired_tokens_account(account: InternalAccount, *, session: "Session"): """" Deletes expired tokens from the database. diff --git a/lib/rucio/core/oidc.py b/lib/rucio/core/oidc.py index 38a7066c2eb..5b58de50bd0 100644 --- a/lib/rucio/core/oidc.py +++ b/lib/rucio/core/oidc.py @@ -38,12 +38,12 @@ from sqlalchemy import delete, select, update from sqlalchemy.sql.expression import true -from rucio.common import types from rucio.common.cache import make_region_memcached from rucio.common.config import config_get, config_get_int from rucio.common.exception import (CannotAuthenticate, CannotAuthorize, RucioException) from rucio.common.stopwatch import Stopwatch +from rucio.common.types import InternalAccount, TokenDict from rucio.common.utils import all_oidc_req_claims_present, build_url, val_to_space_sep_str from rucio.core.account import account_exists from rucio.core.identity import exist_identity_account, get_default_account @@ -469,7 +469,7 @@ def get_auth_oidc(account: str, *, session: "Session", **kwargs) -> str: @transactional_session -def get_token_oidc(auth_query_string: str, ip: str = None, *, session: "Session"): +def get_token_oidc(auth_query_string: str, ip: str = None, *, session: "Session") -> Optional[dict]: """ After Rucio User got redirected to Rucio /auth/oidc_token (or /auth/oidc_code) REST endpoints with authz code and session state encoded within the URL. @@ -479,8 +479,8 @@ def get_token_oidc(auth_query_string: str, ip: str = None, *, session: "Session" :param ip: IP address of the client as a string. :param session: The database session in use. - :returns: One of the following tuples: ("fetchcode", ); ("token", ); - ("polling", True); The result depends on the authentication strategy being used + :returns: One of the following dicts: {"fetchcode": }; {"token": }; + {"polling": True}; The result depends on the authentication strategy being used (no auto, auto, polling). """ try: @@ -630,13 +630,13 @@ def get_token_oidc(auth_query_string: str, ip: str = None, *, session: "Session" @transactional_session -def __get_admin_token_oidc(account: types.InternalAccount, req_scope, req_audience, issuer, *, session: "Session"): +def __get_admin_token_oidc(account: InternalAccount, req_scope: str, req_audience: str, issuer: str, *, session: "Session") -> Optional[TokenDict]: """ Get a token for Rucio application to act on behalf of itself. client_credential flow is used for this purpose. No refresh token is expected to be used. - :param account: the Rucio Admin account name to be used (InternalAccount object expected) + :param account: the Rucio Admin account name to be used :param req_scope: the audience requested for the Rucio client's token :param req_audience: the scope requested for the Rucio client's token :param issuer: the Identity Provider nickname or the Rucio instance in use @@ -689,9 +689,9 @@ def __get_admin_token_oidc(account: types.InternalAccount, req_scope, req_audien @read_session -def __get_admin_account_for_issuer(*, session: "Session"): +def __get_admin_account_for_issuer(*, session: "Session") -> dict: """ Gets admin account for the IdP issuer - :returns : dictionary { 'issuer_1': (account, identity), ... } + :returns: dictionary { 'issuer_1': (account, identity), ... } """ if not OIDC_ADMIN_CLIENTS: @@ -715,7 +715,7 @@ def __get_admin_account_for_issuer(*, session: "Session"): @transactional_session -def get_token_for_account_operation(account: str, req_audience: str = None, req_scope: str = None, admin: bool = False, *, session: "Session"): +def get_token_for_account_operation(account: str, req_audience: str = None, req_scope: str = None, admin: bool = False, *, session: "Session") -> Optional[TokenDict]: """ Looks-up a JWT token with the required scope and audience claims with the account OIDC issuer. If tokens are found, and none contains the requested audience and scope a new token is requested @@ -801,7 +801,7 @@ def get_token_for_account_operation(account: str, req_audience: str = None, req_ for admin_token in admin_account_tokens: if hasattr(admin_token, 'audience') and hasattr(admin_token, 'oidc_scope') and\ all_oidc_req_claims_present(admin_token.oidc_scope, admin_token.audience, req_scope, req_audience): - return token_dictionary(admin_token) + return TokenDict(admin_token.token, admin_token.expired_at) # if not found request a new one new_admin_token = __get_admin_token_oidc(account, req_scope, req_audience, admin_issuer, session=session) return new_admin_token @@ -845,7 +845,7 @@ def get_token_for_account_operation(account: str, req_audience: str = None, req_ for admin_token in admin_account_tokens: if hasattr(admin_token, 'audience') and hasattr(admin_token, 'oidc_scope') and\ all_oidc_req_claims_present(admin_token.oidc_scope, admin_token.audience, req_scope, req_audience): - return token_dictionary(admin_token) + return TokenDict(admin_token.token, admin_token.expired_at) # if no admin token existing was found for the issuer of the valid user token # we request a new one new_admin_token = __get_admin_token_oidc(admin_account, req_scope, req_audience, admin_issuer, session=session) @@ -868,7 +868,7 @@ def get_token_for_account_operation(account: str, req_audience: str = None, req_ for token in account_tokens: if hasattr(token, 'audience') and hasattr(token, 'oidc_scope'): if all_oidc_req_claims_present(token.oidc_scope, token.audience, req_scope, req_audience): - return token_dictionary(token) + return TokenDict(token.token, token.expired_at) # from available tokens select preferentially the one which are being refreshed if hasattr(token, 'oidc_scope') and ('offline_access' in str(token['oidc_scope'])): subject_token = token @@ -1000,7 +1000,7 @@ def __change_refresh_state(token: str, refresh: bool = False, *, session: "Sessi @transactional_session -def refresh_cli_auth_token(token_string: str, account: str, *, session: "Session"): +def refresh_cli_auth_token(token_string: str, account: str, *, session: "Session") -> Optional[tuple]: """ Checks if there is active refresh token and if so returns either active token with expiration timestamp or requests a new @@ -1079,7 +1079,7 @@ def refresh_cli_auth_token(token_string: str, account: str, *, session: "Session @transactional_session -def refresh_jwt_tokens(total_workers: int, worker_number: int, refreshrate: int = 3600, limit: int = 1000, *, session: "Session"): +def refresh_jwt_tokens(total_workers: int, worker_number: int, refreshrate: int = 3600, limit: int = 1000, *, session: "Session") -> int: """ Refreshes tokens which expired or will expire before (now + refreshrate) next run of this function and which have valid refresh token. @@ -1089,7 +1089,7 @@ def refresh_jwt_tokens(total_workers: int, worker_number: int, refreshrate: int :param limit: Maximum number of tokens to refresh per call. :param session: Database session in use. - :return: numper of tokens refreshed + :return: number of tokens refreshed """ nrefreshed = 0 try: @@ -1135,7 +1135,7 @@ def refresh_jwt_tokens(total_workers: int, worker_number: int, refreshrate: int @METRICS.time_it @transactional_session -def __refresh_token_oidc(token_object: models.Token, *, session: "Session"): +def __refresh_token_oidc(token_object: models.Token, *, session: "Session") -> Optional[TokenDict]: """ Requests new access and refresh tokens from the Identity Provider. Assumption: The Identity Provider issues refresh tokens for one time use only and @@ -1212,7 +1212,7 @@ def __refresh_token_oidc(token_object: models.Token, *, session: "Session"): @transactional_session -def delete_expired_oauthrequests(total_workers: int, worker_number: int, limit: int = 1000, *, session: "Session"): +def delete_expired_oauthrequests(total_workers: int, worker_number: int, limit: int = 1000, *, session: "Session") -> int: """ Delete expired OAuth request parameters. @@ -1266,7 +1266,7 @@ def __get_keyvalues_from_claims(token: str, keys=None): """ Extracting claims from token, e.g. scope and audience. :param token: the JWT to be unpacked - :param key: list of key names to extract from the token claims + :param keys: list of key names to extract from the token claims :returns: The list of unicode values under the key, throws an exception otherwise. """ @@ -1286,7 +1286,7 @@ def __get_keyvalues_from_claims(token: str, keys=None): @read_session -def __get_rucio_jwt_dict(jwt: str, account=None, *, session: "Session"): +def __get_rucio_jwt_dict(jwt: str, account: Optional[InternalAccount] = None, *, session: "Session") -> Optional[dict]: """ Get a Rucio token dictionary from token claims. Check token expiration and find default Rucio @@ -1330,7 +1330,7 @@ def __get_rucio_jwt_dict(jwt: str, account=None, *, session: "Session"): @transactional_session -def __save_validated_token(token, valid_dict, extra_dict=None, *, session: "Session"): +def __save_validated_token(token: str, valid_dict: dict, extra_dict: Optional[dict] = None, *, session: "Session") -> TokenDict: """ Save JWT token to the Rucio DB. @@ -1357,7 +1357,7 @@ def __save_validated_token(token, valid_dict, extra_dict=None, *, session: "Sess ip=extra_dict.get('ip', None)) new_token.save(session=session) - return token_dictionary(new_token) + return TokenDict(new_token.token, new_token.expired_at) except Exception as error: raise RucioException(error.args) from error @@ -1433,7 +1433,7 @@ def validate_jwt(json_web_token: str, *, session: "Session") -> dict[str, Any]: raise CannotAuthenticate(traceback.format_exc()) -def oidc_identity_string(sub: str, iss: str): +def oidc_identity_string(sub: str, iss: str) -> str: """ Transform IdP sub claim and issuers url into users identity string. :param sub: users SUB claim from the Identity Provider @@ -1442,7 +1442,3 @@ def oidc_identity_string(sub: str, iss: str): :returns: OIDC identity string "SUB=, ISS=https://iam-test.ch/" """ return 'SUB=' + str(sub) + ', ISS=' + str(iss) - - -def token_dictionary(token: models.Token): - return {'token': token.token, 'expires_at': token.expired_at}