diff --git a/.gitignore b/.gitignore index 6428421..737d3b8 100644 --- a/.gitignore +++ b/.gitignore @@ -114,4 +114,5 @@ dmypy.json # Pyre type checker .pyre/ -# End of https://www.gitignore.io/api/python +AGENTS.md +PROMPTS.md diff --git a/CHANGELOG.md b/CHANGELOG.md index ba4dc2c..dfee6a5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,12 @@ # Changelog ## [Unreleased] +### Added +- OAuth 2.0 Password Grant authentication, by @HardNorth + +## [5.6.5] +### Removed +- Getting message on launch finish, since there is no such field in this response, by @HardNorth ## [5.6.4] ### Added diff --git a/pyproject.toml b/pyproject.toml index 25f7e34..3086587 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,3 +15,9 @@ profile = "black" [tool.black] line-length = 119 target-version = ["py310"] + +[tool.pytest.ini_options] +minversion = "6.0" +required_plugins = "pytest-cov" +testpaths = ["tests"] +asyncio_default_fixture_loop_scope = "session" diff --git a/reportportal_client/__init__.py b/reportportal_client/__init__.py index 2f2c3a2..1d2be53 100644 --- a/reportportal_client/__init__.py +++ b/reportportal_client/__init__.py @@ -12,8 +12,15 @@ # limitations under the License """This package is the base package for ReportPortal client.""" -import typing -import warnings + +import sys +from typing import Optional, Tuple, TypedDict, Union + +# noinspection PyUnreachableCode +if sys.version_info >= (3, 11): + from typing import Unpack +else: + from typing_extensions import Unpack import aenum @@ -34,74 +41,94 @@ class ClientType(aenum.Enum): ASYNC_BATCHED = aenum.auto() +class _ClientOptions(TypedDict, total=False): + client_type: ClientType + endpoint: str + project: str + api_key: Optional[str] + # OAuth 2.0 parameters + oauth_uri: Optional[str] + oauth_username: Optional[str] + oauth_password: Optional[str] + oauth_client_id: Optional[str] + oauth_client_secret: Optional[str] + oauth_scope: Optional[str] + # Common client parameters + launch_uuid: Optional[str] + is_skipped_an_issue: bool + verify_ssl: Union[bool, str] + retries: int + max_pool_size: int + http_timeout: Union[float, Tuple[float, float]] + mode: str + launch_uuid_print: bool + print_output: OutputType + truncate_attributes: bool + log_batch_size: int + log_batch_payload_limit: int + # Async client specific parameters + keepalive_timeout: float + # Async threaded/batched client specific parameters + task_timeout: float + shutdown_timeout: float + # Async batched client specific parameters + trigger_num: int + trigger_interval: float + + # noinspection PyIncorrectDocstring def create_client( - client_type: ClientType, endpoint: str, project: str, *, api_key: str = None, **kwargs: typing.Any -) -> typing.Optional[RP]: + client_type: ClientType, endpoint: str, project: str, **kwargs: Unpack[_ClientOptions] +) -> Optional[RP]: """Create and ReportPortal Client based on the type and arguments provided. :param client_type: Type of the Client to create. - :type client_type: ClientType :param endpoint: Endpoint of the ReportPortal service. - :type endpoint: str :param project: Project name to report to. - :type project: str :param api_key: Authorization API key. - :type api_key: str + :param oauth_uri: OAuth 2.0 token endpoint URI (for OAuth authentication). + :param oauth_username: Username for OAuth 2.0 authentication. + :param oauth_password: Password for OAuth 2.0 authentication. + :param oauth_client_id: OAuth 2.0 client ID. + :param oauth_client_secret: OAuth 2.0 client secret (optional). + :param oauth_scope: OAuth 2.0 scope (optional). :param launch_uuid: A launch UUID to use instead of starting own one. - :type launch_uuid: str :param is_skipped_an_issue: Option to mark skipped tests as not 'To Investigate' items on the server side. - :type is_skipped_an_issue: bool :param verify_ssl: Option to skip ssl verification. - :type verify_ssl: typing.Union[bool, str] :param retries: Number of retry attempts to make in case of connection / server errors. - :type retries: int :param max_pool_size: Option to set the maximum number of connections to save the pool. - :type max_pool_size: int :param http_timeout : A float in seconds for connect and read timeout. Use a Tuple to specific connect and read separately. - :type http_timeout: Tuple[float, float] :param mode: Launch mode, all Launches started by the client will be in that mode. - :type mode: str :param launch_uuid_print: Print Launch UUID into passed TextIO or by default to stdout. - :type launch_uuid_print: bool :param print_output: Set output stream for Launch UUID printing. - :type print_output: OutputType :param truncate_attributes: Truncate test item attributes to default maximum length. - :type truncate_attributes: bool :param log_batch_size: Option to set the maximum number of logs that can be processed in one batch. - :type log_batch_size: int :param log_batch_payload_limit: Maximum size in bytes of logs that can be processed in one batch. - :type log_batch_payload_limit: int :param keepalive_timeout: For Async Clients only. Maximum amount of idle time in seconds before force connection closing. - :type keepalive_timeout: int :param task_timeout: For Async Threaded and Batched Clients only. Time limit in seconds for a Task processing. - :type task_timeout: float :param shutdown_timeout: For Async Threaded and Batched Clients only. Time limit in seconds for shutting down internal Tasks. - :type shutdown_timeout: float :param trigger_num: For Async Batched Client only. Number of tasks which triggers Task batch execution. - :type trigger_num: int :param trigger_interval: For Async Batched Client only. Time limit which triggers Task batch execution. - :type trigger_interval: float :return: ReportPortal Client instance. """ if client_type is ClientType.SYNC: - return RPClient(endpoint, project, api_key=api_key, **kwargs) + return RPClient(endpoint, project, **kwargs) if client_type is ClientType.ASYNC: - return AsyncRPClient(endpoint, project, api_key=api_key, **kwargs) + return AsyncRPClient(endpoint, project, **kwargs) if client_type is ClientType.ASYNC_THREAD: - return ThreadedRPClient(endpoint, project, api_key=api_key, **kwargs) + return ThreadedRPClient(endpoint, project, **kwargs) if client_type is ClientType.ASYNC_BATCHED: - return BatchedRPClient(endpoint, project, api_key=api_key, **kwargs) - warnings.warn(f"Unknown ReportPortal Client type requested: {client_type}", RuntimeWarning, stacklevel=2) + return BatchedRPClient(endpoint, project, **kwargs) + raise ValueError(f"Unknown ReportPortal Client type requested: {client_type}") __all__ = [ diff --git a/reportportal_client/_internal/aio/http.py b/reportportal_client/_internal/aio/http.py index 7dc7b69..e969239 100644 --- a/reportportal_client/_internal/aio/http.py +++ b/reportportal_client/_internal/aio/http.py @@ -24,15 +24,20 @@ import asyncio import sys from types import TracebackType -from typing import Any, Callable, Coroutine, Optional, Type +from typing import Any, Callable, Coroutine, Optional, Type, Union from aenum import Enum -from aiohttp import ClientResponse, ClientResponseError, ClientSession, ServerConnectionError +from aiohttp import ClientResponse, ClientResponseError +from aiohttp import ClientSession as AioHttpClientSession +from aiohttp import ServerConnectionError + +from reportportal_client._internal.services.auth import AuthAsync DEFAULT_RETRY_NUMBER: int = 5 DEFAULT_RETRY_DELAY: float = 0.005 THROTTLING_STATUSES: set = {425, 429} RETRY_STATUSES: set = {408, 500, 502, 503, 507}.union(THROTTLING_STATUSES) +AUTH_PROBLEM_STATUSES: set = {401, 403} class RetryClass(int, Enum): @@ -46,7 +51,7 @@ class RetryClass(int, Enum): class RetryingClientSession: """Class uses aiohttp.ClientSession.request method and adds request retry logic.""" - _client: ClientSession + _client: AioHttpClientSession __retry_number: int __retry_delay: float @@ -68,7 +73,7 @@ def __init__( an error. Real value highly depends on Retry Class and Retry attempt number, since retries are performed in exponential delay manner """ - self._client = ClientSession(*args, **kwargs) + self._client = AioHttpClientSession(*args, **kwargs) self.__retry_number = max_retry_number self.__retry_delay = base_retry_delay @@ -91,8 +96,12 @@ async def __request(self, method: Callable, url, **kwargs: Any) -> ClientRespons """ result = None exceptions = [] + for i in range(self.__retry_number + 1): # add one for the first attempt, which is not a retry retry_factor = None + if result is not None: + # Release previous result to return connection to pool + await result.release() try: result = await method(url, **kwargs) except Exception as exc: @@ -157,3 +166,84 @@ async def __aexit__( ) -> None: """Auxiliary method which controls what `async with` construction does on block exit.""" await self.close() + + +class ClientSession: + """Class wraps aiohttp.ClientSession or RetryingClientSession and adds authentication support.""" + + _client: Union[AioHttpClientSession, RetryingClientSession] + __auth: Optional[AuthAsync] + + def __init__( + self, + wrapped: Union[AioHttpClientSession, RetryingClientSession], + auth: Optional[AuthAsync] = None, + ): + """Initialize an instance of the session with arguments. + + :param wrapped: aiohttp.ClientSession or RetryingClientSession instance to wrap + :param auth: authentication instance to use for requests + """ + self._client = wrapped + self.__auth = auth + + async def __request(self, method: Callable, url: str, **kwargs: Any) -> ClientResponse: + """Make a request with authentication support. + + The method adds Authorization header if auth is configured and handles auth refresh + on 401/403 responses. + """ + # Clone kwargs and add Authorization header if auth is configured + request_kwargs = kwargs.copy() + if self.__auth: + auth_header = await self.__auth.get() + if auth_header: + if "headers" not in request_kwargs: + request_kwargs["headers"] = {} + else: + request_kwargs["headers"] = request_kwargs["headers"].copy() + request_kwargs["headers"]["Authorization"] = auth_header + + result = await method(url, **request_kwargs) + + # Check for authentication errors + if result.status in AUTH_PROBLEM_STATUSES and self.__auth: + refreshed_header = await self.__auth.refresh() + if refreshed_header: + # Release previous result to return connection to pool + await result.release() + # Retry with new auth header + request_kwargs["headers"] = request_kwargs.get("headers", {}).copy() + request_kwargs["headers"]["Authorization"] = refreshed_header + result = await method(url, **request_kwargs) + + return result + + def get(self, url: str, *, allow_redirects: bool = True, **kwargs: Any) -> Coroutine[Any, Any, ClientResponse]: + """Perform HTTP GET request.""" + return self.__request(self._client.get, url, allow_redirects=allow_redirects, **kwargs) + + def post(self, url: str, *, data: Any = None, **kwargs: Any) -> Coroutine[Any, Any, ClientResponse]: + """Perform HTTP POST request.""" + return self.__request(self._client.post, url, data=data, **kwargs) + + def put(self, url: str, *, data: Any = None, **kwargs: Any) -> Coroutine[Any, Any, ClientResponse]: + """Perform HTTP PUT request.""" + return self.__request(self._client.put, url, data=data, **kwargs) + + def close(self) -> Coroutine: + """Gracefully close internal session instance.""" + return self._client.close() + + async def __aenter__(self) -> "ClientSession": + """Auxiliary method which controls what `async with` construction does on block enter.""" + return self + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + """Auxiliary method which controls what `async with` construction does on block exit.""" + await self.close() diff --git a/reportportal_client/_internal/aio/tasks.py b/reportportal_client/_internal/aio/tasks.py index fea52c0..00cbfbb 100644 --- a/reportportal_client/_internal/aio/tasks.py +++ b/reportportal_client/_internal/aio/tasks.py @@ -177,20 +177,24 @@ def append(self, value: _T) -> Optional[List[_T]]: :return: a batch or None """ self.__task_list.append(value) - if self.__ready_to_run(): - tasks = self.__task_list - self.__task_list = [] - return tasks + if not self.__ready_to_run(): + return None + + tasks = self.__task_list + self.__task_list = [] + return tasks def flush(self) -> Optional[List[_T]]: """Immediately return everything what's left in the internal batch. :return: a batch or None """ - if len(self.__task_list) > 0: - tasks = self.__task_list - self.__task_list = [] - return tasks + if len(self.__task_list) <= 0: + return None + + tasks = self.__task_list + self.__task_list = [] + return tasks class BackgroundTaskList(Generic[_T]): @@ -224,7 +228,9 @@ def flush(self) -> Optional[List[_T]]: :return: a batch or None """ self.__remove_finished() - if len(self.__task_list) > 0: - tasks = self.__task_list - self.__task_list = [] - return tasks + if len(self.__task_list) <= 0: + return None + + tasks = self.__task_list + self.__task_list = [] + return tasks diff --git a/reportportal_client/_internal/http.py b/reportportal_client/_internal/http.py new file mode 100644 index 0000000..06d1513 --- /dev/null +++ b/reportportal_client/_internal/http.py @@ -0,0 +1,112 @@ +# Copyright 2025 EPAM Systems +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License + +"""This module designed to help with synchronous HTTP request/response handling.""" + +from types import TracebackType +from typing import Any, Callable, Optional, Type, Union + +from requests import Response, Session +from requests.adapters import BaseAdapter + +from reportportal_client._internal.services.auth import Auth + +AUTH_PROBLEM_STATUSES: set = {401, 403} + + +class ClientSession: + """Class wraps requests.Session and adds authentication support.""" + + _client: Session + __auth: Optional[Auth] + + def __init__( + self, + auth: Optional[Auth] = None, + ): + """Initialize an instance of the session with arguments. + + :param auth: authentication instance to use for requests + """ + self._client = Session() + self.__auth = auth + + def __request(self, method: Callable, url: Union[str, bytes], **kwargs: Any) -> Response: + """Make a request with authentication support. + + The method adds Authorization header if auth is configured and handles auth refresh + on 401/403 responses. + """ + # Clone kwargs and add Authorization header if auth is configured + request_kwargs = kwargs.copy() + if self.__auth: + auth_header = self.__auth.get() + if auth_header: + if "headers" not in request_kwargs: + request_kwargs["headers"] = {} + else: + request_kwargs["headers"] = request_kwargs["headers"].copy() + request_kwargs["headers"]["Authorization"] = auth_header + + result = method(url, **request_kwargs) + + # Check for authentication errors + if result.status_code in AUTH_PROBLEM_STATUSES and self.__auth: + refreshed_header = self.__auth.refresh() + if refreshed_header: + # Close previous result if it's retried to release resources + result.close() + # Retry with new auth header + request_kwargs["headers"] = request_kwargs.get("headers", {}).copy() + request_kwargs["headers"]["Authorization"] = refreshed_header + result = method(url, **request_kwargs) + + return result + + def get(self, url: Union[str, bytes], **kwargs: Any) -> Response: + """Perform HTTP GET request.""" + return self.__request(self._client.get, url, **kwargs) + + def post(self, url: Union[str, bytes], **kwargs: Any) -> Response: + """Perform HTTP POST request.""" + return self.__request(self._client.post, url, **kwargs) + + def put(self, url: Union[str, bytes], **kwargs: Any) -> Response: + """Perform HTTP PUT request.""" + return self.__request(self._client.put, url, **kwargs) + + def mount(self, prefix: str, adapter: BaseAdapter) -> None: + """Mount an adapter to a specific URL prefix. + + :param prefix: URL prefix (e.g., 'http://', 'https://') + :param adapter: Adapter instance to mount + """ + self._client.mount(prefix, adapter) + + def close(self) -> None: + """Gracefully close internal requests.Session class instance.""" + self._client.close() + + def __enter__(self) -> "ClientSession": + """Auxiliary method which controls what `with` construction does on block enter.""" + return self + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + """Auxiliary method which controls what `with` construction does on block exit.""" + self.close() diff --git a/reportportal_client/_internal/local/__init__.py b/reportportal_client/_internal/local/__init__.py index 7438976..da4dabb 100644 --- a/reportportal_client/_internal/local/__init__.py +++ b/reportportal_client/_internal/local/__init__.py @@ -22,6 +22,7 @@ def current(): """Return current ReportPortal client.""" if hasattr(__INSTANCES, "current"): return __INSTANCES.current + return None def set_current(client): diff --git a/reportportal_client/_internal/logs/batcher.py b/reportportal_client/_internal/logs/batcher.py index b47b77e..fc7378c 100644 --- a/reportportal_client/_internal/logs/batcher.py +++ b/reportportal_client/_internal/logs/batcher.py @@ -50,7 +50,7 @@ def __init__(self, entry_num=MAX_LOG_BATCH_SIZE, payload_limit=MAX_LOG_BATCH_PAY self._batch = [] self._payload_size = 0 - def _append(self, size: int, log_req: T_co) -> Optional[List[T_co]]: + def _append(self, size: int, log_req: RPRequestLog) -> Optional[List[RPRequestLog]]: with self._lock: if self._payload_size + size >= self.payload_limit: if len(self._batch) > 0: @@ -60,11 +60,13 @@ def _append(self, size: int, log_req: T_co) -> Optional[List[T_co]]: return batch self._batch.append(log_req) self._payload_size += size - if len(self._batch) >= self.entry_num: - batch = self._batch - self._batch = [] - self._payload_size = 0 - return batch + if len(self._batch) < self.entry_num: + return None + + batch = self._batch + self._batch = [] + self._payload_size = 0 + return batch def append(self, log_req: RPRequestLog) -> Optional[List[RPRequestLog]]: """Add a log request object to internal batch and return the batch if it's full. @@ -87,12 +89,15 @@ def flush(self) -> Optional[List[T_co]]: :return: a batch or None """ + if len(self._batch) <= 0: + return None with self._lock: - if len(self._batch) > 0: - batch = self._batch - self._batch = [] - self._payload_size = 0 - return batch + if len(self._batch) <= 0: + return None + batch = self._batch + self._batch = [] + self._payload_size = 0 + return batch def __getstate__(self) -> Dict[str, Any]: """Control object pickling and return object fields as Dictionary. diff --git a/reportportal_client/_internal/services/auth.py b/reportportal_client/_internal/services/auth.py new file mode 100644 index 0000000..437ba62 --- /dev/null +++ b/reportportal_client/_internal/services/auth.py @@ -0,0 +1,521 @@ +# Copyright 2025 EPAM Systems +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""This module handles authentication for ReportPortal.""" + +import logging +import time +from abc import abstractmethod +from typing import Optional + +import aiohttp +import requests + +from reportportal_client._internal.static.abstract import AbstractBaseClass + +logger = logging.getLogger(__name__) + + +# noinspection PyAbstractClass +class Auth(metaclass=AbstractBaseClass): + """Abstract base class for synchronous authentication. + + This class defines the interface for all synchronous authentication methods. + """ + + __metaclass__ = AbstractBaseClass + + @abstractmethod + def get(self) -> Optional[str]: + """Get valid Authorization header value. + + :return: Authorization header value or None if authentication failed. + """ + raise NotImplementedError('"get" method is not implemented!') + + @abstractmethod + def refresh(self) -> Optional[str]: + """Refresh the access token and return Authorization header value. + + :return: Authorization header value or None if refresh failed. + """ + raise NotImplementedError('"refresh" method is not implemented!') + + +# noinspection PyAbstractClass +class AuthAsync(metaclass=AbstractBaseClass): + """Abstract base class for asynchronous authentication. + + This class defines the interface for all asynchronous authentication methods. + """ + + __metaclass__ = AbstractBaseClass + + @abstractmethod + async def get(self) -> Optional[str]: + """Get valid Authorization header value. + + :return: Authorization header value or None if authentication failed. + """ + raise NotImplementedError('"get" method is not implemented!') + + @abstractmethod + async def refresh(self) -> Optional[str]: + """Refresh the access token and return Authorization header value. + + :return: Authorization header value or None if refresh failed. + """ + raise NotImplementedError('"refresh" method is not implemented!') + + +class ApiKeyAuthSync(Auth): + """Synchronous API key authentication. + + This class provides simple key-based authentication that always returns + the provided API key. + """ + + api_key: str + + def __init__(self, api_key: str) -> None: + """Initialize API key authentication. + + :param api_key: API key for authentication. + """ + self.api_key = api_key + + def get(self) -> Optional[str]: + """Get valid Authorization header value. + + :return: Authorization header value with Bearer token. + """ + return f"Bearer {self.api_key}" + + def refresh(self) -> None: + """Refresh the access key and return Authorization header value. + + For API keys, this simply returns None as there's no refresh mechanism. + + :return: None + """ + return None + + +class ApiKeyAuthAsync(AuthAsync): + """Asynchronous API key authentication. + + This class provides simple key-based authentication that always returns + the provided API key. + """ + + api_key: str + + def __init__(self, api_key: str) -> None: + """Initialize API key authentication. + + :param api_key: API key for authentication. + """ + self.api_key = api_key + + async def get(self) -> Optional[str]: + """Get valid Authorization header value. + + :return: Authorization header value with Bearer token. + """ + return f"Bearer {self.api_key}" + + async def refresh(self) -> None: + """Refresh the access key and return Authorization header value. + + For API keys, this simply returns None as there's no refresh mechanism. + + :return: None + """ + return None + + +# noinspection PyAbstractClass +class OAuthPasswordGrant: + """Base class for OAuth 2.0 password grant authentication. + + This class provides common logic for obtaining and refreshing access tokens using + the OAuth 2.0 password grant flow. This class should not be used directly, use + OAuthPasswordGrantSync or OAuthPasswordGrantAsync instead. + """ + + oauth_uri: str + username: str + password: str + client_id: str + client_secret: Optional[str] + scope: Optional[str] + _access_token: Optional[str] + _refresh_token: Optional[str] + _token_expires_at: Optional[float] + _last_attempt_time: Optional[float] + + def __init__( + self, + oauth_uri: str, + username: str, + password: str, + client_id: str, + client_secret: Optional[str] = None, + scope: Optional[str] = None, + ) -> None: + """Initialize OAuth 2.0 password grant authentication. + + :param oauth_uri: OAuth 2.0 token endpoint URI. + :param username: Username for authentication. + :param password: Password for authentication. + :param client_id: OAuth client ID. + :param client_secret: Optional OAuth client secret. + :param scope: Optional OAuth scope. + """ + self.oauth_uri = oauth_uri + self.username = username + self.password = password + self.client_id = client_id + self.client_secret = client_secret + self.scope = scope + self._access_token = None + self._refresh_token = None + self._token_expires_at = None + self._last_attempt_time = None + + def _should_skip_request(self) -> bool: + """Check if token request should be skipped due to throttling. + + :return: True if request should be skipped, False otherwise. + """ + if self._last_attempt_time is None: + return False + current_time = time.time() + return int(current_time) == int(self._last_attempt_time) + + def _is_token_expired(self) -> bool: + """Check if the current access token is expired. + + :return: True if token is expired or not set, False otherwise. + """ + if not self._access_token or self._token_expires_at is None: + return True + return time.time() >= self._token_expires_at + + def _update_last_attempt_time(self) -> None: + """Update the last attempt time to current time.""" + self._last_attempt_time = time.time() + + def _clear_token(self) -> None: + """Clear the current access token.""" + self._access_token = None + self._token_expires_at = None + + def _parse_token_response(self, response_data: dict) -> bool: + """Parse OAuth token response and store tokens. + + :param response_data: Response JSON data from OAuth server. + :return: True if parsing was successful, False otherwise. + """ + try: + access_token = response_data.get("access_token") + if not access_token: + logger.warning("OAuth token response missing 'access_token' field") + return False + + self._access_token = access_token + self._refresh_token = response_data.get("refresh_token") + + expires_in = response_data.get("expires_in") + if expires_in: + # Set expiration time with 30 seconds buffer to avoid edge cases + self._token_expires_at = time.time() + int(expires_in) - 30 + else: + # If expires_in is not provided, assume token is valid for a reasonable time + self._token_expires_at = time.time() + 3600 # 1 hour default + + return True + except (ValueError, TypeError) as e: + logger.warning(f"Failed to parse OAuth token response: {e}") + return False + + def _build_token_request_data(self, grant_type: str, **extra_params) -> dict: + """Build request data for OAuth token request. + + :param grant_type: OAuth grant type. + :param extra_params: Additional parameters for the request. + :return: Dictionary with request data. + """ + data = {"grant_type": grant_type, "client_id": self.client_id} + + if self.client_secret: + data["client_secret"] = self.client_secret + + if self.scope: + data["scope"] = self.scope + + data.update(extra_params) + return data + + +class OAuthPasswordGrantSync(OAuthPasswordGrant, Auth): + """Synchronous implementation of OAuth 2.0 password grant authentication.""" + + _session: Optional[requests.Session] + + def __init__( + self, + oauth_uri: str, + username: str, + password: str, + client_id: str, + client_secret: Optional[str] = None, + scope: Optional[str] = None, + session: Optional[requests.Session] = None, + ) -> None: + """Initialize OAuth 2.0 password grant authentication. + + :param oauth_uri: OAuth 2.0 token endpoint URI. + :param username: Username for authentication. + :param password: Password for authentication. + :param client_id: OAuth client ID. + :param client_secret: Optional OAuth client secret. + :param scope: Optional OAuth scope. + :param session: Optional requests.Session instance to use. + """ + super().__init__(oauth_uri, username, password, client_id, client_secret, scope) + self._session = session + + def _get_session(self) -> requests.Session: + """Get or create requests.Session. + + :return: Session instance. + """ + if self._session is None: + self._session = requests.Session() + return self._session + + def _execute_token_request(self, data: dict) -> bool: + """Execute token request to OAuth server. + + :param data: Request data. + :return: True if request was successful, False otherwise. + """ + self._update_last_attempt_time() + + try: + session = self._get_session() + response = session.post( + self.oauth_uri, + data=data, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + ) + + if not response.ok: + logger.warning(f"OAuth token request failed with status {response.status_code}") + return False + + return self._parse_token_response(response.json()) + + except requests.exceptions.RequestException as e: + logger.warning(f"OAuth token request failed: {e}") + return False + + def _obtain_new_token(self) -> bool: + """Obtain new access token using password grant. + + :return: True if token was obtained successfully, False otherwise. + """ + data = self._build_token_request_data( + grant_type="password", + username=self.username, + password=self.password, + ) + return self._execute_token_request(data) + + def _refresh_access_token(self) -> bool: + """Refresh access token using refresh token. + + :return: True if token was refreshed successfully, False otherwise. + """ + if not self._refresh_token: + return False + + data = self._build_token_request_data( + grant_type="refresh_token", + refresh_token=self._refresh_token, + ) + return self._execute_token_request(data) + + def get(self) -> Optional[str]: + """Get valid Authorization header value. + + :return: Authorization header value or None if authentication failed. + """ + # If token is valid, return it (no need to check throttling for cached token) + if not self._is_token_expired(): + return f"Bearer {self._access_token}" + + # Check if we should skip new request due to throttling + if self._should_skip_request(): + return None + + # Try to refresh token first + if self._refresh_access_token(): + return f"Bearer {self._access_token}" + + # If refresh failed, try to obtain new token + if self._obtain_new_token(): + return f"Bearer {self._access_token}" + + return None + + def refresh(self) -> Optional[str]: + """Refresh the access token and return Authorization header value. + + :return: Authorization header value or None if refresh failed. + """ + self._clear_token() + return self.get() + + def close(self) -> None: + """Close the session and release resources.""" + if self._session: + self._session.close() + + +class OAuthPasswordGrantAsync(OAuthPasswordGrant, AuthAsync): + """Asynchronous implementation of OAuth 2.0 password grant authentication.""" + + _session: Optional[aiohttp.ClientSession] + + def __init__( + self, + oauth_uri: str, + username: str, + password: str, + client_id: str, + client_secret: Optional[str] = None, + scope: Optional[str] = None, + session: Optional[aiohttp.ClientSession] = None, + ) -> None: + """Initialize OAuth 2.0 password grant authentication. + + :param oauth_uri: OAuth 2.0 token endpoint URI. + :param username: Username for authentication. + :param password: Password for authentication. + :param client_id: OAuth client ID. + :param client_secret: Optional OAuth client secret. + :param scope: Optional OAuth scope. + :param session: Optional aiohttp.ClientSession instance to use. + """ + super().__init__(oauth_uri, username, password, client_id, client_secret, scope) + self._session = session + + def _get_session(self) -> aiohttp.ClientSession: + """Get or create aiohttp ClientSession. + + :return: ClientSession instance. + """ + if self._session is None or self._session.closed: + self._session = aiohttp.ClientSession() + return self._session + + async def _execute_token_request(self, data: dict) -> bool: + """Execute token request to OAuth server. + + :param data: Request data. + :return: True if request was successful, False otherwise. + """ + self._update_last_attempt_time() + + try: + session = self._get_session() + async with session.post( + self.oauth_uri, + data=data, + headers={"Content-Type": "application/x-www-form-urlencoded"}, + ) as response: + if not response.ok: + logger.warning(f"OAuth token request failed with status {response.status}") + return False + + response_data = await response.json() + return self._parse_token_response(response_data) + + except aiohttp.ClientError as e: + logger.warning(f"OAuth token request failed: {e}") + return False + + async def _obtain_new_token(self) -> bool: + """Obtain new access token using password grant. + + :return: True if token was obtained successfully, False otherwise. + """ + data = self._build_token_request_data( + grant_type="password", + username=self.username, + password=self.password, + ) + return await self._execute_token_request(data) + + async def _refresh_access_token(self) -> bool: + """Refresh access token using refresh token. + + :return: True if token was refreshed successfully, False otherwise. + """ + if not self._refresh_token: + return False + + data = self._build_token_request_data( + grant_type="refresh_token", + refresh_token=self._refresh_token, + ) + return await self._execute_token_request(data) + + async def get(self) -> Optional[str]: + """Get valid Authorization header value. + + :return: Authorization header value or None if authentication failed. + """ + # If token is valid, return it (no need to check throttling for cached token) + if not self._is_token_expired(): + return f"Bearer {self._access_token}" + + # Check if we should skip new request due to throttling + if self._should_skip_request(): + return None + + # Try to refresh token first + if await self._refresh_access_token(): + return f"Bearer {self._access_token}" + + # If refresh failed, try to obtain new token + if await self._obtain_new_token(): + return f"Bearer {self._access_token}" + + return None + + async def refresh(self) -> Optional[str]: + """Refresh the access token and return Authorization header value. + + :return: Authorization header value or None if refresh failed. + """ + self._clear_token() + return await self.get() + + async def close(self) -> None: + """Close the session and release resources.""" + if self._session and not self._session.closed: + await self._session.close() diff --git a/reportportal_client/_internal/services/client_id.py b/reportportal_client/_internal/services/client_id.py index 07a804b..ed7c0b7 100644 --- a/reportportal_client/_internal/services/client_id.py +++ b/reportportal_client/_internal/services/client_id.py @@ -57,6 +57,7 @@ def _read_client_id(): config = __read_config() if config.has_option(__NoSectionConfigParser.DEFAULT_SECTION, CLIENT_ID_PROPERTY): return config.get(__NoSectionConfigParser.DEFAULT_SECTION, CLIENT_ID_PROPERTY) + return None def _store_client_id(client_id): diff --git a/reportportal_client/_internal/services/statistics.py b/reportportal_client/_internal/services/statistics.py index a791eb1..b65f432 100644 --- a/reportportal_client/_internal/services/statistics.py +++ b/reportportal_client/_internal/services/statistics.py @@ -73,7 +73,9 @@ def _get_payload(event_name: str, agent_name: Optional[str], agent_version: Opti return {"client_id": get_client_id(), "events": [{"name": event_name, "params": request_params}]} -def send_event(event_name: str, agent_name: Optional[str], agent_version: Optional[str]) -> requests.Response: +def send_event( + event_name: str, agent_name: Optional[str], agent_version: Optional[str] +) -> Optional[requests.Response]: """Send an event to statistics service. Use client and agent versions with their names. @@ -93,6 +95,7 @@ def send_event(event_name: str, agent_name: Optional[str], agent_version: Option ) except requests.exceptions.RequestException as err: logger.debug("Failed to send data to Statistics service: %s", str(err)) + return None async def async_send_event( @@ -119,8 +122,8 @@ async def async_send_event( ssl=ssl_context, ) except aiohttp.ClientError as exc: - logger.debug("Failed to send data to Statistics service: connection error", exc) - return + logger.debug("Failed to send data to Statistics service: %s", str(exc)) + return None if not result.ok: logger.debug(f"Failed to send data to Statistics service: {result.reason}") return result diff --git a/reportportal_client/aio/client.py b/reportportal_client/aio/client.py index 3a90a1a..06f5583 100644 --- a/reportportal_client/aio/client.py +++ b/reportportal_client/aio/client.py @@ -26,7 +26,7 @@ import certifi # noinspection PyProtectedMember -from reportportal_client._internal.aio.http import RetryingClientSession +from reportportal_client._internal.aio.http import ClientSession, RetryingClientSession # noinspection PyProtectedMember from reportportal_client._internal.aio.tasks import ( @@ -44,6 +44,9 @@ # noinspection PyProtectedMember from reportportal_client._internal.logs.batcher import LogBatcher +# noinspection PyProtectedMember +from reportportal_client._internal.services.auth import ApiKeyAuthAsync, AuthAsync, OAuthPasswordGrantAsync + # noinspection PyProtectedMember from reportportal_client._internal.services.statistics import async_send_event @@ -101,7 +104,14 @@ class Client: endpoint: str is_skipped_an_issue: bool project: str - api_key: str + api_key: Optional[str] + oauth_uri: Optional[str] + oauth_username: Optional[str] + oauth_password: Optional[str] + oauth_client_id: Optional[str] + oauth_client_secret: Optional[str] + oauth_scope: Optional[str] + auth: AuthAsync verify_ssl: Union[bool, str] retries: Optional[int] max_pool_size: int @@ -112,7 +122,7 @@ class Client: print_output: OutputType truncate_attributes: bool _skip_analytics: str - _session: Optional[RetryingClientSession] + _session: Optional[ClientSession] __stat_task: Optional[asyncio.Task] def __init__( @@ -120,7 +130,7 @@ def __init__( endpoint: str, project: str, *, - api_key: str = None, + api_key: Optional[str] = None, is_skipped_an_issue: bool = True, verify_ssl: Union[bool, str] = True, retries: int = NOT_SET, @@ -131,13 +141,26 @@ def __init__( launch_uuid_print: bool = False, print_output: OutputType = OutputType.STDOUT, truncate_attributes: bool = True, - **_: Any, + # OAuth 2.0 Password Grant parameters + oauth_uri: Optional[str] = None, + oauth_username: Optional[str] = None, + oauth_password: Optional[str] = None, + oauth_client_id: Optional[str] = None, + oauth_client_secret: Optional[str] = None, + oauth_scope: Optional[str] = None, + **kwargs: Any, ) -> None: """Initialize the class instance with arguments. :param endpoint: Endpoint of the ReportPortal service. :param project: Project name to report to. :param api_key: Authorization API key. + :param oauth_uri: OAuth 2.0 token endpoint URI (for OAuth authentication). + :param oauth_username: Username for OAuth 2.0 authentication. + :param oauth_password: Password for OAuth 2.0 authentication. + :param oauth_client_id: OAuth 2.0 client ID. + :param oauth_client_secret: OAuth 2.0 client secret (optional). + :param oauth_scope: OAuth 2.0 scope (optional). :param is_skipped_an_issue: Option to mark skipped tests as not 'To Investigate' items on the server side. :param verify_ssl: Option to skip ssl verification. @@ -168,13 +191,62 @@ def __init__( self.print_output = print_output self._session = None self.__stat_task = None - self.api_key = api_key self.truncate_attributes = truncate_attributes - async def session(self) -> RetryingClientSession: - """Return aiohttp.ClientSession class instance, initialize it if necessary. + self.api_key = api_key + # Handle deprecated token argument + if not self.api_key and "token" in kwargs: + warnings.warn( + message="Argument `token` is deprecated since 5.3.5 and will be subject for removing in " + "the next major version. Use `api_key` argument instead.", + category=DeprecationWarning, + stacklevel=2, + ) + self.api_key = kwargs["token"] + + self.oauth_uri = oauth_uri + self.oauth_username = oauth_username + self.oauth_password = oauth_password + self.oauth_client_id = oauth_client_id + self.oauth_client_secret = oauth_client_secret + self.oauth_scope = oauth_scope + + # Initialize authentication + oauth_params = [oauth_uri, oauth_username, oauth_password, oauth_client_id] + oauth_provided = all(oauth_params) + + if oauth_provided: + # Use OAuth 2.0 Password Grant authentication + self.auth = OAuthPasswordGrantAsync( + oauth_uri=oauth_uri, + username=oauth_username, + password=oauth_password, + client_id=oauth_client_id, + client_secret=oauth_client_secret, + scope=oauth_scope, + ) + elif self.api_key: + self.auth = ApiKeyAuthAsync(api_key) + else: + # Neither OAuth nor API key provided + raise ValueError( + "Authentication credentials are required. Please provide either:\n" + "1. OAuth 2.0 parameters: oauth_uri, username, password, and client_id\n" + " (with optional client_secret and scope), or\n" + "2. api_key parameter for API key authentication.\n" + "\n" + "Example for OAuth:\n" + " Client(endpoint='...', project='...', oauth_oauth_uri='https://example.com/oauth/token',\n" + " oauth_username='user', oauth_password='pass', oauth_client_id='client_id')\n" + "\n" + "Example for API key:\n" + " Client(endpoint='...', project='...', api_key='your_api_key')" + ) + + async def session(self) -> ClientSession: + """Return ClientSession class instance, initialize it if necessary. - :return: aiohttp.ClientSession instance. + :return: ClientSession instance. """ if self._session: return self._session @@ -187,16 +259,12 @@ async def session(self) -> RetryingClientSession: else: ssl_config = ssl.create_default_context(ssl.Purpose.SERVER_AUTH, cafile=certifi.where()) - connection_params = {"ssl": ssl_config, "limit": self.max_pool_size} + connection_params: Dict[str, Any] = {"ssl": ssl_config, "limit": self.max_pool_size} if self.keepalive_timeout: connection_params["keepalive_timeout"] = self.keepalive_timeout connector = aiohttp.TCPConnector(**connection_params) - headers = {} - if self.api_key: - headers["Authorization"] = f"Bearer {self.api_key}" - - session_params = {"headers": headers, "connector": connector} + session_params: Dict[str, Any] = {"connector": connector} if self.http_timeout: if type(self.http_timeout) is tuple: @@ -212,10 +280,12 @@ async def session(self) -> RetryingClientSession: session_params["max_retry_number"] = self.retries if use_retries: - self._session = RetryingClientSession(self.endpoint, **session_params) + wrapped_session = RetryingClientSession(self.endpoint, **session_params) else: # noinspection PyTypeChecker - self._session = aiohttp.ClientSession(self.endpoint, **session_params) + wrapped_session = aiohttp.ClientSession(self.endpoint, **session_params) + + self._session = ClientSession(wrapped=wrapped_session, auth=self.auth) return self._session async def close(self) -> None: @@ -565,14 +635,14 @@ async def log_batch(self, log_batch: Optional[List[AsyncRPRequestLog]]) -> Optio :param log_batch: A list of log message objects. :return: Completion message tuple of variable size (depending on request size). """ + if not log_batch: + return None + url = root_uri_join(self.base_url_v2, "log") - if log_batch: - response = await ErrorPrintingAsyncHttpRequest( - (await self.session()).post, url=url, data=AsyncRPLogBatch(log_batch).payload, name="log" - ).make() - if not response: - return None - return await response.messages + response = await ErrorPrintingAsyncHttpRequest( + (await self.session()).post, url=url, data=AsyncRPLogBatch(log_batch).payload, name="log" + ).make() + return await response.messages if response else None def clone(self) -> "Client": """Clone the client object, set current Item ID as cloned item ID. @@ -593,6 +663,12 @@ def clone(self) -> "Client": mode=self.mode, launch_uuid_print=self.launch_uuid_print, print_output=self.print_output, + oauth_uri=self.oauth_uri, + oauth_username=self.oauth_username, + oauth_password=self.oauth_password, + oauth_client_id=self.oauth_client_id, + oauth_client_secret=self.oauth_client_secret, + oauth_scope=self.oauth_scope, ) return cloned @@ -689,6 +765,12 @@ def __init__( :param endpoint: Endpoint of the ReportPortal service. :param project: Project name to report to. :param api_key: Authorization API key. + :param oauth_uri: OAuth 2.0 token endpoint URI (for OAuth authentication). + :param oauth_username: Username for OAuth 2.0 authentication. + :param oauth_password: Password for OAuth 2.0 authentication. + :param oauth_client_id: OAuth 2.0 client ID. + :param oauth_client_secret: OAuth 2.0 client secret (optional). + :param oauth_scope: OAuth 2.0 scope (optional). :param is_skipped_an_issue: Option to mark skipped tests as not 'To Investigate' items on the server side. :param verify_ssl: Option to skip ssl verification. @@ -1010,6 +1092,7 @@ async def close(self) -> None: await self.__client.close() +# noinspection PyAbstractClass class _RPClient(RP, metaclass=AbstractBaseClass): """Base class for different synchronous to asynchronous client implementations.""" @@ -1454,7 +1537,7 @@ class ThreadedRPClient(_RPClient): task_timeout: float shutdown_timeout: float - _task_list: BackgroundTaskList[Task[_T]] + _task_list: BackgroundTaskList[Task[Any]] _task_mutex: threading.RLock _loop: Optional[asyncio.AbstractEventLoop] _thread: Optional[threading.Thread] @@ -1510,6 +1593,12 @@ def __init__( :param endpoint: Endpoint of the ReportPortal service. :param project: Project name to report to. :param api_key: Authorization API key. + :param oauth_uri: OAuth 2.0 token endpoint URI (for OAuth authentication). + :param oauth_username: Username for OAuth 2.0 authentication. + :param oauth_password: Password for OAuth 2.0 authentication. + :param oauth_client_id: OAuth 2.0 client ID. + :param oauth_client_secret: OAuth 2.0 client secret (optional). + :param oauth_scope: OAuth 2.0 scope (optional). :param is_skipped_an_issue: Option to mark skipped tests as not 'To Investigate' items on the server side. :param verify_ssl: Option to skip ssl verification. @@ -1645,7 +1734,7 @@ class BatchedRPClient(_RPClient): trigger_interval: float _loop: asyncio.AbstractEventLoop _task_mutex: threading.RLock - _task_list: TriggerTaskBatcher[Task[_T]] + _task_list: TriggerTaskBatcher[Task[Any]] __last_run_time: float def __init_task_list( @@ -1692,6 +1781,12 @@ def __init__( :param endpoint: Endpoint of the ReportPortal service. :param project: Project name to report to. :param api_key: Authorization API key. + :param oauth_uri: OAuth 2.0 token endpoint URI (for OAuth authentication). + :param oauth_username: Username for OAuth 2.0 authentication. + :param oauth_password: Password for OAuth 2.0 authentication. + :param oauth_client_id: OAuth 2.0 client ID. + :param oauth_client_secret: OAuth 2.0 client secret (optional). + :param oauth_scope: OAuth 2.0 scope (optional). :param is_skipped_an_issue: Option to mark skipped tests as not 'To Investigate' items on the server side. :param verify_ssl: Option to skip ssl verification. diff --git a/reportportal_client/client.py b/reportportal_client/client.py index 02247ae..6252258 100644 --- a/reportportal_client/client.py +++ b/reportportal_client/client.py @@ -1,15 +1,16 @@ -# Copyright (c) 2023 EPAM Systems +# Copyright 2025 EPAM Systems +# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# https://www.apache.org/licenses/LICENSE-2.0 +# https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and -# limitations under the License +# limitations under the License. """This module contains ReportPortal Client interface and synchronous implementation class.""" @@ -22,9 +23,11 @@ from typing import Any, Dict, List, Optional, TextIO, Tuple, Union import aenum -import requests from requests.adapters import DEFAULT_RETRIES, HTTPAdapter, Retry +# noinspection PyProtectedMember +from reportportal_client._internal.http import ClientSession + # noinspection PyProtectedMember from reportportal_client._internal.local import set_current @@ -32,12 +35,13 @@ from reportportal_client._internal.logs.batcher import LogBatcher # noinspection PyProtectedMember -from reportportal_client._internal.services.statistics import send_event +from reportportal_client._internal.services.auth import ApiKeyAuthSync, Auth, OAuthPasswordGrantSync # noinspection PyProtectedMember -from reportportal_client._internal.static.abstract import AbstractBaseClass +from reportportal_client._internal.services.statistics import send_event # noinspection PyProtectedMember +from reportportal_client._internal.static.abstract import AbstractBaseClass from reportportal_client.core.rp_issues import Issue from reportportal_client.core.rp_requests import ( ErrorPrintingHttpRequest, @@ -66,12 +70,13 @@ class OutputType(aenum.Enum): def get_output(self) -> Optional[TextIO]: """Return TextIO based on the current type.""" - if self == OutputType.STDOUT: - return sys.stdout if self == OutputType.STDERR: return sys.stderr + else: + return sys.stdout +# noinspection PyAbstractClass class RP(metaclass=AbstractBaseClass): """Common interface for ReportPortal clients. @@ -387,12 +392,19 @@ class RPClient(RP): log_batch_size: int log_batch_payload_size: int __project: str - api_key: str + api_key: Optional[str] + oauth_uri: Optional[str] + oauth_username: Optional[str] + oauth_password: Optional[str] + oauth_client_id: Optional[str] + oauth_client_secret: Optional[str] + oauth_scope: Optional[str] + auth: Auth verify_ssl: Union[bool, str] retries: int max_pool_size: int http_timeout: Union[float, Tuple[float, float]] - session: requests.Session + session: ClientSession __step_reporter: StepReporter mode: str launch_uuid_print: Optional[bool] @@ -440,19 +452,17 @@ def __init_session(self) -> None: if self.retries else DEFAULT_RETRIES ) - session = requests.Session() + session = ClientSession(auth=self.auth) session.mount("https://", HTTPAdapter(max_retries=retry_strategy, pool_maxsize=self.max_pool_size)) # noinspection HttpUrlsUsage session.mount("http://", HTTPAdapter(max_retries=retry_strategy, pool_maxsize=self.max_pool_size)) - if self.api_key: - session.headers["Authorization"] = "Bearer {0}".format(self.api_key) self.session = session def __init__( self, endpoint: str, project: str, - api_key: str = None, + api_key: Optional[str] = None, log_batch_size: int = 20, is_skipped_an_issue: bool = True, verify_ssl: Union[bool, str] = True, @@ -466,6 +476,13 @@ def __init__( print_output: OutputType = OutputType.STDOUT, log_batcher: Optional[LogBatcher[RPRequestLog]] = None, truncate_attributes: bool = True, + # OAuth 2.0 Password Grant parameters + oauth_uri: Optional[str] = None, + oauth_username: Optional[str] = None, + oauth_password: Optional[str] = None, + oauth_client_id: Optional[str] = None, + oauth_client_secret: Optional[str] = None, + oauth_scope: Optional[str] = None, **kwargs: Any, ) -> None: """Initialize the class instance with arguments. @@ -473,6 +490,12 @@ def __init__( :param endpoint: Endpoint of the ReportPortal service. :param project: Project name to report to. :param api_key: Authorization API key. + :param oauth_uri: OAuth 2.0 token endpoint URI (for OAuth authentication). + :param oauth_username: Username for OAuth 2.0 authentication. + :param oauth_password: Password for OAuth 2.0 authentication. + :param oauth_client_id: OAuth 2.0 client ID. + :param oauth_client_secret: OAuth 2.0 client secret (optional). + :param oauth_scope: OAuth 2.0 scope (optional). :param log_batch_size: Option to set the maximum number of logs that can be processed in one batch. :param is_skipped_an_issue: Option to mark skipped tests as not 'To Investigate' items on the @@ -525,24 +548,54 @@ def __init__( self.truncate_attributes = truncate_attributes self.api_key = api_key - if not self.api_key: - if "token" in kwargs: - warnings.warn( - message="Argument `token` is deprecated since 5.3.5 and will be subject for removing in " - "the next major version. Use `api_key` argument instead.", - category=DeprecationWarning, - stacklevel=2, - ) - self.api_key = kwargs["token"] - - if not self.api_key: - warnings.warn( - message="Argument `api_key` is `None` or empty string, that is not supposed to happen " - "because ReportPortal is usually requires an authorization key. Please check " - "your code.", - category=RuntimeWarning, - stacklevel=2, - ) + # Handle deprecated token argument + if not self.api_key and "token" in kwargs: + warnings.warn( + message="Argument `token` is deprecated since 5.3.5 and will be subject for removing in " + "the next major version. Use `api_key` argument instead.", + category=DeprecationWarning, + stacklevel=2, + ) + self.api_key = kwargs["token"] + + self.oauth_uri = oauth_uri + self.oauth_username = oauth_username + self.oauth_password = oauth_password + self.oauth_client_id = oauth_client_id + self.oauth_client_secret = oauth_client_secret + self.oauth_scope = oauth_scope + + # Initialize authentication + oauth_params = [oauth_uri, oauth_username, oauth_password, oauth_client_id] + oauth_provided = all(oauth_params) + + if oauth_provided: + # Use OAuth 2.0 Password Grant authentication + self.auth = OAuthPasswordGrantSync( + oauth_uri=oauth_uri, + username=oauth_username, + password=oauth_password, + client_id=oauth_client_id, + client_secret=oauth_client_secret, + scope=oauth_scope, + ) + elif self.api_key: + self.auth = ApiKeyAuthSync(api_key) + else: + # Neither OAuth nor API key provided + raise ValueError( + "Authentication credentials are required. Please provide either:\n" + "1. OAuth 2.0 parameters: oauth_uri, username, password, and client_id\n" + " (with optional client_secret and scope), or\n" + "2. api_key parameter for API key authentication.\n" + "\n" + "Example for OAuth:\n" + " RPClient(endpoint='...', project='...', oauth_uri='https://example.com/oauth/token',\n" + " username='user', password='pass', client_id='client_id')\n" + "\n" + "Example for API key:\n" + " RPClient(endpoint='...', project='...', api_key='your_api_key')" + ) self.__init_session() @@ -804,18 +857,19 @@ def update_test_item( return response.message def _log(self, batch: Optional[List[RPRequestLog]]) -> Optional[Tuple[str, ...]]: - if batch: - url = uri_join(self.base_url_v2, "log") - response = ErrorPrintingHttpRequest( - self.session.post, - url, - files=RPLogBatch(batch).payload, - verify_ssl=self.verify_ssl, - http_timeout=self.http_timeout, - name="log", - ).make() - if response: - return response.messages + if not batch: + return None + + url = uri_join(self.base_url_v2, "log") + response = ErrorPrintingHttpRequest( + self.session.post, + url, + files=RPLogBatch(batch).payload, + verify_ssl=self.verify_ssl, + http_timeout=self.http_timeout, + name="log", + ).make() + return response.messages if response else None def log( self, @@ -965,6 +1019,12 @@ def clone(self) -> "RPClient": log_batch_payload_size=self.log_batch_payload_size, mode=self.mode, log_batcher=self._log_batcher, + oauth_uri=self.oauth_uri, + oauth_username=self.oauth_username, + oauth_password=self.oauth_password, + oauth_client_id=self.oauth_client_id, + oauth_client_secret=self.oauth_client_secret, + oauth_scope=self.oauth_scope, ) current_item = self.current_item() if current_item: diff --git a/reportportal_client/helpers/common_helpers.py b/reportportal_client/helpers/common_helpers.py index b4d4dab..a36abf8 100644 --- a/reportportal_client/helpers/common_helpers.py +++ b/reportportal_client/helpers/common_helpers.py @@ -68,7 +68,7 @@ class LifoQueue(Generic[_T]): """Primitive thread-safe Last-in-first-out queue implementation.""" - _lock: threading.Lock() + _lock: threading.Lock __items: List[_T] def __init__(self): @@ -98,9 +98,13 @@ def last(self) -> _T: :return: The last element in the queue. """ + if len(self.__items) <= 0: + return None + with self._lock: if len(self.__items) > 0: return self.__items[-1] + return None def qsize(self): """Return the queue size.""" @@ -149,7 +153,7 @@ def dict_to_payload(dictionary: Optional[dict]) -> Optional[List[dict]]: hidden = my_dictionary.pop("system", None) result = [] for key, value in sorted(my_dictionary.items()): - attribute = {"key": str(key), "value": str(value)} + attribute: Dict[str, Any] = {"key": str(key), "value": str(value)} if hidden is not None: attribute["system"] = hidden result.append(attribute) diff --git a/requirements-dev.txt b/requirements-dev.txt index e01a820..b72c70d 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,3 +1,5 @@ pytest pytest-cov pytest-asyncio +black +isort diff --git a/requirements.txt b/requirements.txt index 5871f2d..1012378 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,5 @@ aenum +typing-extensions>=4.13.2 requests>=2.32.3 aiohttp>=3.10.11 certifi>=2024.8.30 diff --git a/setup.py b/setup.py index 98db5f0..1dfea31 100644 --- a/setup.py +++ b/setup.py @@ -4,7 +4,7 @@ from setuptools import find_packages, setup -__version__ = "5.6.5" +__version__ = "5.6.6" TYPE_STUBS = ["*.pyi"] @@ -36,7 +36,7 @@ def read_file(fname): author_email="support@reportportal.io", url="https://github.com/reportportal/client-Python", download_url=("https://github.com/reportportal/client-Python/" "tarball/%s" % __version__), - license="Apache 2.0.", + license="Apache-2.0", keywords=["testing", "reporting", "reportportal", "client"], classifiers=[ "Programming Language :: Python :: 3.8", diff --git a/tests/_internal/aio/test_http.py b/tests/_internal/aio/test_aio_http.py similarity index 55% rename from tests/_internal/aio/test_http.py rename to tests/_internal/aio/test_aio_http.py index 1be66c6..44d3601 100644 --- a/tests/_internal/aio/test_http.py +++ b/tests/_internal/aio/test_aio_http.py @@ -30,7 +30,10 @@ import pytest # noinspection PyProtectedMember -from reportportal_client._internal.aio.http import RetryingClientSession +from reportportal_client._internal.aio.http import ClientSession, RetryingClientSession + +# noinspection PyProtectedMember +from reportportal_client._internal.services.auth import ApiKeyAuthAsync HTTP_TIMEOUT_TIME = 1.2 @@ -75,14 +78,27 @@ def do_GET(self): self.wfile.flush() +class UnauthorizedHttpHandler(http.server.BaseHTTPRequestHandler): + def do_GET(self): + auth_header = self.headers.get("Authorization") + if auth_header == "Bearer test_api_key": + self.send_response(200) + self.send_header("Content-Type", "application/json") + self.end_headers() + self.wfile.write("{}\n\n".encode("utf-8")) + else: + self.send_response(401, "Unauthorized") + self.end_headers() + self.wfile.write("Unauthorized\n\n".encode("utf-8")) + self.wfile.flush() + + SERVER_PORT = 8000 SERVER_ADDRESS = ("", SERVER_PORT) -SERVER_CLASS = socketserver.TCPServer -SERVER_HANDLER_CLASS = http.server.BaseHTTPRequestHandler -def get_http_server(server_class=SERVER_CLASS, server_address=SERVER_ADDRESS, server_handler=SERVER_HANDLER_CLASS): - httpd = server_class(server_address, server_handler) +def get_http_server(*, server_handler, server_address=SERVER_ADDRESS): + httpd = socketserver.TCPServer(server_address, server_handler) thread = threading.Thread(target=httpd.serve_forever, daemon=True) thread.start() return httpd @@ -105,7 +121,7 @@ async def execute_http_request(port, retry_number, server_class, timeout_seconds exception = None result = None with get_http_server(server_handler=server_class, server_address=("", port)): - with mock.patch("reportportal_client._internal.aio.http.ClientSession.get", async_mock): + with mock.patch("reportportal_client._internal.aio.http.AioHttpClientSession.get", async_mock): async with session: start_time = time.time() try: @@ -163,3 +179,93 @@ async def test_no_retry_on_not_retryable_error(): assert result is None assert async_mock.call_count == 1 assert total_time < 1 + + +@pytest.mark.asyncio +async def test_auth_header_added_to_request(): + """Test that auth header is added to requests when auth is configured.""" + port = 8006 + retry_number = 5 + auth = ApiKeyAuthAsync("test_api_key") + timeout = aiohttp.ClientTimeout(connect=1.0, sock_read=1.0) + connector = aiohttp.TCPConnector(force_close=True) + wrapped_session = RetryingClientSession( + f"http://localhost:{port}", + timeout=timeout, + max_retry_number=retry_number, + base_retry_delay=0.01, + connector=connector, + ) + session = ClientSession(wrapped=wrapped_session, auth=auth) + + with get_http_server(server_handler=UnauthorizedHttpHandler, server_address=("", port)): + async with session: + result = await session.get("/") + assert result.ok + assert result.status == 200 + + +@pytest.mark.asyncio +async def test_auth_refresh_on_401(): + """Test that 401 response triggers auth refresh.""" + port = 8007 + retry_number = 5 + + # Create a mock auth that fails first, then succeeds + auth = mock.AsyncMock() + auth.get = mock.AsyncMock(side_effect=["Bearer invalid_token", "Bearer test_api_key"]) + auth.refresh = mock.AsyncMock(return_value="Bearer test_api_key") + + timeout = aiohttp.ClientTimeout(connect=1.0, sock_read=1.0) + connector = aiohttp.TCPConnector(force_close=True) + wrapped_session = RetryingClientSession( + f"http://localhost:{port}", + timeout=timeout, + max_retry_number=retry_number, + base_retry_delay=0.01, + connector=connector, + ) + session = ClientSession(wrapped=wrapped_session, auth=auth) + + with get_http_server(server_handler=UnauthorizedHttpHandler, server_address=("", port)): + async with session: + result = await session.get("/") + # First call to get() returns invalid token, which causes 401 + # Then refresh() is called and returns valid token + # Request is retried with valid token and succeeds + assert result.ok + assert result.status == 200 + assert auth.get.call_count == 1 + assert auth.refresh.call_count == 1 + + +@pytest.mark.asyncio +async def test_auth_refresh_only_once(): + """Test that auth refresh is only performed once per request.""" + port = 8008 + retry_number = 5 + + # Create a mock auth that always fails + auth = mock.AsyncMock() + auth.get = mock.AsyncMock(return_value="Bearer invalid_token") + auth.refresh = mock.AsyncMock(return_value="Bearer still_invalid_token") + + timeout = aiohttp.ClientTimeout(connect=1.0, sock_read=1.0) + connector = aiohttp.TCPConnector(force_close=True) + wrapped_session = RetryingClientSession( + f"http://localhost:{port}", + timeout=timeout, + max_retry_number=retry_number, + base_retry_delay=0.01, + connector=connector, + ) + session = ClientSession(wrapped=wrapped_session, auth=auth) + + with get_http_server(server_handler=UnauthorizedHttpHandler, server_address=("", port)): + async with session: + result = await session.get("/") + # Auth refresh should only be attempted once + assert not result.ok + assert result.status == 401 + assert auth.get.call_count == 1 + assert auth.refresh.call_count == 1 diff --git a/tests/_internal/services/test_auth.py b/tests/_internal/services/test_auth.py new file mode 100644 index 0000000..bbaede0 --- /dev/null +++ b/tests/_internal/services/test_auth.py @@ -0,0 +1,601 @@ +"""This module contains unit tests for authentication.""" + +# Copyright 2025 EPAM Systems +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License + +import asyncio +import time +from unittest import mock + +# noinspection PyPackageRequirements +import pytest + +# noinspection PyProtectedMember +from reportportal_client._internal.services.auth import ( + ApiKeyAuthAsync, + ApiKeyAuthSync, + OAuthPasswordGrantAsync, + OAuthPasswordGrantSync, +) + +OAUTH_URI = "https://example.com/oauth/token" +USERNAME = "testuser" +PASSWORD = "testpass" +CLIENT_ID = "test_client_id" +CLIENT_SECRET = "test_client_secret" +SCOPE = "test_scope" +ACCESS_TOKEN = "test_access_token" +REFRESH_TOKEN = "test_refresh_token" +EXPIRES_IN = 3600 + + +def create_token_response(access_token=ACCESS_TOKEN, refresh_token=REFRESH_TOKEN, expires_in=EXPIRES_IN): + """Create a mock OAuth token response.""" + return {"access_token": access_token, "refresh_token": refresh_token, "expires_in": expires_in} + + +class TestOAuthPasswordGrantSync: + """Tests for synchronous OAuth 2.0 password grant authentication.""" + + def test_happy_path_fresh_start(self): + """Test successful token acquisition on fresh start.""" + mock_response = mock.Mock() + mock_response.ok = True + mock_response.json.return_value = create_token_response() + + mock_session = mock.Mock() + mock_session.post.return_value = mock_response + + oauth = OAuthPasswordGrantSync( + OAUTH_URI, USERNAME, PASSWORD, CLIENT_ID, CLIENT_SECRET, SCOPE, session=mock_session + ) + result = oauth.get() + + assert result == f"Bearer {ACCESS_TOKEN}" + assert oauth._access_token == ACCESS_TOKEN + assert oauth._refresh_token == REFRESH_TOKEN + mock_session.post.assert_called_once() + + # Verify request data + call_args = mock_session.post.call_args + assert call_args[0][0] == OAUTH_URI + data = call_args[1]["data"] + assert data["grant_type"] == "password" + assert data["username"] == USERNAME + assert data["password"] == PASSWORD + assert data["client_id"] == CLIENT_ID + assert data["client_secret"] == CLIENT_SECRET + assert data["scope"] == SCOPE + + def test_happy_path_token_refresh(self): + """Test successful token refresh after expiration.""" + mock_response = mock.Mock() + mock_response.ok = True + mock_response.json.return_value = create_token_response(expires_in=1) + + mock_session = mock.Mock() + mock_session.post.return_value = mock_response + + oauth = OAuthPasswordGrantSync(OAUTH_URI, USERNAME, PASSWORD, CLIENT_ID, session=mock_session) + + # First call - obtain token + result1 = oauth.get() + assert result1 == f"Bearer {ACCESS_TOKEN}" + assert mock_session.post.call_count == 1 + + # Wait for token to expire + time.sleep(1) + + # Second call - token expired, should refresh + new_access_token = "new_access_token" + mock_response.json.return_value = create_token_response(access_token=new_access_token) + + result2 = oauth.get() + assert result2 == f"Bearer {new_access_token}" + assert oauth._access_token == new_access_token + assert mock_session.post.call_count == 2 + + # Verify refresh token was used + call_args = mock_session.post.call_args_list[1] + data = call_args[1]["data"] + assert data["grant_type"] == "refresh_token" + assert data["refresh_token"] == REFRESH_TOKEN + + def test_token_refresh_throttling(self): + """Test that token requests are throttled within the same second after failure.""" + mock_response_fail = mock.Mock() + mock_response_fail.ok = False + mock_response_fail.status_code = 401 + + mock_session = mock.Mock() + mock_session.post.return_value = mock_response_fail + + oauth = OAuthPasswordGrantSync(OAUTH_URI, USERNAME, PASSWORD, CLIENT_ID, session=mock_session) + + # First call - should fail + result1 = oauth.get() + assert result1 is None + assert mock_session.post.call_count == 1 + + # Second call in the same second - should be throttled + result2 = oauth.get() + assert result2 is None + assert mock_session.post.call_count == 1 # No additional call + + def test_initial_request_401_response(self): + """Test 401 response from server on initial request.""" + mock_response = mock.Mock() + mock_response.ok = False + mock_response.status_code = 401 + + mock_session = mock.Mock() + mock_session.post.return_value = mock_response + + oauth = OAuthPasswordGrantSync(OAUTH_URI, USERNAME, PASSWORD, CLIENT_ID, session=mock_session) + result = oauth.get() + + assert result is None + assert oauth._access_token is None + + def test_initial_request_403_response(self): + """Test 403 response from server on initial request.""" + mock_response = mock.Mock() + mock_response.ok = False + mock_response.status_code = 403 + + mock_session = mock.Mock() + mock_session.post.return_value = mock_response + + oauth = OAuthPasswordGrantSync(OAUTH_URI, USERNAME, PASSWORD, CLIENT_ID, session=mock_session) + result = oauth.get() + + assert result is None + assert oauth._access_token is None + + def test_refresh_request_401_fallback_to_password_grant(self): + """Test 401 response on refresh request falls back to password grant.""" + # First obtain a token + initial_response = mock.Mock() + initial_response.ok = True + initial_response.json.return_value = create_token_response(expires_in=1) + + mock_session = mock.Mock() + mock_session.post.return_value = initial_response + + oauth = OAuthPasswordGrantSync(OAUTH_URI, USERNAME, PASSWORD, CLIENT_ID, session=mock_session) + result1 = oauth.get() + assert result1 == f"Bearer {ACCESS_TOKEN}" + + # Wait for token to expire + time.sleep(1) + + # Simulate refresh failure and password grant success + refresh_response = mock.Mock() + refresh_response.ok = False + refresh_response.status_code = 401 + + password_response = mock.Mock() + password_response.ok = True + new_token = "new_password_token" + password_response.json.return_value = create_token_response(access_token=new_token) + + mock_session.post.side_effect = [refresh_response, password_response] + + result2 = oauth.get() + assert result2 == f"Bearer {new_token}" + assert oauth._access_token == new_token + assert mock_session.post.call_count == 3 # Initial + refresh attempt + password grant + + def test_refresh_request_403_fallback_to_password_grant(self): + """Test 403 response on refresh request falls back to password grant.""" + # First obtain a token + initial_response = mock.Mock() + initial_response.ok = True + initial_response.json.return_value = create_token_response(expires_in=1) + + mock_session = mock.Mock() + mock_session.post.return_value = initial_response + + oauth = OAuthPasswordGrantSync(OAUTH_URI, USERNAME, PASSWORD, CLIENT_ID, session=mock_session) + result1 = oauth.get() + assert result1 == f"Bearer {ACCESS_TOKEN}" + + # Wait for token to expire + time.sleep(1) + + # Simulate refresh failure with 403 and password grant success + refresh_response = mock.Mock() + refresh_response.ok = False + refresh_response.status_code = 403 + + password_response = mock.Mock() + password_response.ok = True + new_token = "new_password_token" + password_response.json.return_value = create_token_response(access_token=new_token) + + mock_session.post.side_effect = [refresh_response, password_response] + + result2 = oauth.get() + assert result2 == f"Bearer {new_token}" + assert oauth._access_token == new_token + + def test_refresh_method_on_valid_token(self): + """Test refresh method call on existing and not expired access token.""" + initial_response = mock.Mock() + initial_response.ok = True + initial_response.json.return_value = create_token_response() + + mock_session = mock.Mock() + mock_session.post.return_value = initial_response + + oauth = OAuthPasswordGrantSync(OAUTH_URI, USERNAME, PASSWORD, CLIENT_ID, session=mock_session) + + # Get initial token + result1 = oauth.get() + assert result1 == f"Bearer {ACCESS_TOKEN}" + assert oauth._access_token == ACCESS_TOKEN + + # Call refresh on valid token + new_token = "refreshed_token" + refreshed_response = mock.Mock() + refreshed_response.ok = True + refreshed_response.json.return_value = create_token_response(access_token=new_token) + mock_session.post.return_value = refreshed_response + + # Wait to avoid throttling + time.sleep(1) + + result2 = oauth.refresh() + assert result2 == f"Bearer {new_token}" + assert oauth._access_token == new_token + + # Verify it tried to use refresh token (since refresh clears the access token but refresh token is still + # available) + call_args = mock_session.post.call_args + data = call_args[1]["data"] + # After refresh() clears the access token, it should try to use refresh_token grant since refresh token is + # still available + assert data["grant_type"] == "refresh_token" + assert data["refresh_token"] == REFRESH_TOKEN + + +class TestOAuthPasswordGrantAsync: + """Tests for asynchronous OAuth 2.0 password grant authentication.""" + + @pytest.mark.asyncio + async def test_happy_path_fresh_start(self): + """Test successful token acquisition on fresh start.""" + mock_response = mock.AsyncMock() + mock_response.ok = True + mock_response.json.return_value = create_token_response() + mock_response.__aenter__.return_value = mock_response + mock_response.__aexit__.return_value = None + + mock_session = mock.Mock() + mock_session.post = mock.Mock(return_value=mock_response) + mock_session.closed = False + + oauth = OAuthPasswordGrantAsync( + OAUTH_URI, USERNAME, PASSWORD, CLIENT_ID, CLIENT_SECRET, SCOPE, session=mock_session + ) + result = await oauth.get() + + assert result == f"Bearer {ACCESS_TOKEN}" + assert oauth._access_token == ACCESS_TOKEN + assert oauth._refresh_token == REFRESH_TOKEN + mock_session.post.assert_called_once() + + # Verify request data + call_args = mock_session.post.call_args + data = call_args[1]["data"] + assert data["grant_type"] == "password" + assert data["username"] == USERNAME + assert data["password"] == PASSWORD + assert data["client_id"] == CLIENT_ID + assert data["client_secret"] == CLIENT_SECRET + assert data["scope"] == SCOPE + + @pytest.mark.asyncio + async def test_happy_path_token_refresh(self): + """Test successful token refresh after expiration.""" + mock_response = mock.AsyncMock() + mock_response.ok = True + mock_response.json.return_value = create_token_response(expires_in=1) + mock_response.__aenter__.return_value = mock_response + mock_response.__aexit__.return_value = None + + mock_session = mock.Mock() + mock_session.post = mock.Mock(return_value=mock_response) + mock_session.closed = False + + oauth = OAuthPasswordGrantAsync(OAUTH_URI, USERNAME, PASSWORD, CLIENT_ID, session=mock_session) + + # First call - obtain token + result1 = await oauth.get() + assert result1 == f"Bearer {ACCESS_TOKEN}" + assert mock_session.post.call_count == 1 + + # Wait for token to expire + await asyncio.sleep(1) + + # Second call - token expired, should refresh + new_access_token = "new_access_token" + mock_response.json.return_value = create_token_response(access_token=new_access_token) + + result2 = await oauth.get() + assert result2 == f"Bearer {new_access_token}" + assert oauth._access_token == new_access_token + assert mock_session.post.call_count == 2 + + # Verify refresh token was used + call_args = mock_session.post.call_args_list[1] + data = call_args[1]["data"] + assert data["grant_type"] == "refresh_token" + assert data["refresh_token"] == REFRESH_TOKEN + + @pytest.mark.asyncio + async def test_token_refresh_throttling(self): + """Test that token requests are throttled within the same second after failure.""" + mock_response_fail = mock.AsyncMock() + mock_response_fail.ok = False + mock_response_fail.status = 401 + mock_response_fail.__aenter__.return_value = mock_response_fail + mock_response_fail.__aexit__.return_value = None + + mock_session = mock.Mock() + mock_session.post = mock.Mock(return_value=mock_response_fail) + mock_session.closed = False + + oauth = OAuthPasswordGrantAsync(OAUTH_URI, USERNAME, PASSWORD, CLIENT_ID, session=mock_session) + + # First call - should fail + result1 = await oauth.get() + assert result1 is None + assert mock_session.post.call_count == 1 + + # Second call in the same second - should be throttled + result2 = await oauth.get() + assert result2 is None + assert mock_session.post.call_count == 1 # No additional call + + @pytest.mark.asyncio + async def test_initial_request_401_response(self): + """Test 401 response from server on initial request.""" + mock_response = mock.AsyncMock() + mock_response.ok = False + mock_response.status = 401 + mock_response.__aenter__.return_value = mock_response + mock_response.__aexit__.return_value = None + + mock_session = mock.Mock() + mock_session.post = mock.Mock(return_value=mock_response) + mock_session.closed = False + + oauth = OAuthPasswordGrantAsync(OAUTH_URI, USERNAME, PASSWORD, CLIENT_ID, session=mock_session) + result = await oauth.get() + + assert result is None + assert oauth._access_token is None + + @pytest.mark.asyncio + async def test_initial_request_403_response(self): + """Test 403 response from server on initial request.""" + mock_response = mock.AsyncMock() + mock_response.ok = False + mock_response.status = 403 + mock_response.__aenter__.return_value = mock_response + mock_response.__aexit__.return_value = None + + mock_session = mock.Mock() + mock_session.post = mock.Mock(return_value=mock_response) + mock_session.closed = False + + oauth = OAuthPasswordGrantAsync(OAUTH_URI, USERNAME, PASSWORD, CLIENT_ID, session=mock_session) + result = await oauth.get() + + assert result is None + assert oauth._access_token is None + + @pytest.mark.asyncio + async def test_refresh_request_401_fallback_to_password_grant(self): + """Test 401 response on refresh request falls back to password grant.""" + # First obtain a token + initial_response = mock.AsyncMock() + initial_response.ok = True + initial_response.json.return_value = create_token_response(expires_in=1) + initial_response.__aenter__.return_value = initial_response + initial_response.__aexit__.return_value = None + + mock_session = mock.Mock() + mock_session.post = mock.Mock(return_value=initial_response) + mock_session.closed = False + + oauth = OAuthPasswordGrantAsync(OAUTH_URI, USERNAME, PASSWORD, CLIENT_ID, session=mock_session) + result1 = await oauth.get() + assert result1 == f"Bearer {ACCESS_TOKEN}" + + # Wait for token to expire + await asyncio.sleep(1) + + # Simulate refresh failure and password grant success + refresh_response = mock.AsyncMock() + refresh_response.ok = False + refresh_response.status = 401 + refresh_response.__aenter__.return_value = refresh_response + refresh_response.__aexit__.return_value = None + + password_response = mock.AsyncMock() + password_response.ok = True + new_token = "new_password_token" + password_response.json.return_value = create_token_response(access_token=new_token) + password_response.__aenter__.return_value = password_response + password_response.__aexit__.return_value = None + + mock_session.post.side_effect = [refresh_response, password_response] + + result2 = await oauth.get() + assert result2 == f"Bearer {new_token}" + assert oauth._access_token == new_token + assert mock_session.post.call_count == 3 # Initial + refresh attempt + password grant + + @pytest.mark.asyncio + async def test_refresh_request_403_fallback_to_password_grant(self): + """Test 403 response on refresh request falls back to password grant.""" + # First obtain a token + initial_response = mock.AsyncMock() + initial_response.ok = True + initial_response.json.return_value = create_token_response(expires_in=1) + initial_response.__aenter__.return_value = initial_response + initial_response.__aexit__.return_value = None + + mock_session = mock.Mock() + mock_session.post = mock.Mock(return_value=initial_response) + mock_session.closed = False + + oauth = OAuthPasswordGrantAsync(OAUTH_URI, USERNAME, PASSWORD, CLIENT_ID, session=mock_session) + result1 = await oauth.get() + assert result1 == f"Bearer {ACCESS_TOKEN}" + + # Wait for token to expire + await asyncio.sleep(1) + + # Simulate refresh failure with 403 and password grant success + refresh_response = mock.AsyncMock() + refresh_response.ok = False + refresh_response.status = 403 + refresh_response.__aenter__.return_value = refresh_response + refresh_response.__aexit__.return_value = None + + password_response = mock.AsyncMock() + password_response.ok = True + new_token = "new_password_token" + password_response.json.return_value = create_token_response(access_token=new_token) + password_response.__aenter__.return_value = password_response + password_response.__aexit__.return_value = None + + mock_session.post.side_effect = [refresh_response, password_response] + + result2 = await oauth.get() + assert result2 == f"Bearer {new_token}" + assert oauth._access_token == new_token + + @pytest.mark.asyncio + async def test_refresh_method_on_valid_token(self): + """Test refresh method call on existing and not expired access token.""" + initial_response = mock.AsyncMock() + initial_response.ok = True + initial_response.json.return_value = create_token_response() + initial_response.__aenter__.return_value = initial_response + initial_response.__aexit__.return_value = None + + mock_session = mock.Mock() + mock_session.post = mock.Mock(return_value=initial_response) + mock_session.closed = False + + oauth = OAuthPasswordGrantAsync(OAUTH_URI, USERNAME, PASSWORD, CLIENT_ID, session=mock_session) + + # Get initial token + result1 = await oauth.get() + assert result1 == f"Bearer {ACCESS_TOKEN}" + assert oauth._access_token == ACCESS_TOKEN + + # Call refresh on valid token + new_token = "refreshed_token" + refreshed_response = mock.AsyncMock() + refreshed_response.ok = True + refreshed_response.json.return_value = create_token_response(access_token=new_token) + refreshed_response.__aenter__.return_value = refreshed_response + refreshed_response.__aexit__.return_value = None + mock_session.post.return_value = refreshed_response + + # Wait to avoid throttling + await asyncio.sleep(1) + + result2 = await oauth.refresh() + assert result2 == f"Bearer {new_token}" + assert oauth._access_token == new_token + + # Verify it tried to use refresh token (since refresh clears the access token but refresh token is still + # available) + call_args = mock_session.post.call_args + data = call_args[1]["data"] + # After refresh() clears the access token, it should try to use refresh_token grant since refresh token is + # still available + assert data["grant_type"] == "refresh_token" + assert data["refresh_token"] == REFRESH_TOKEN + + +class TestApiTokenAuthSync: + """Tests for synchronous API token authentication.""" + + def test_get_returns_token(self): + """Test that get() returns the API token.""" + api_token = "test_api_token_12345" + auth = ApiKeyAuthSync(api_token) + result = auth.get() + + assert result == f"Bearer {api_token}" + + def test_refresh_returns_none(self): + """Test that refresh() returns None (API keys don't have refresh mechanism).""" + api_token = "test_api_token_67890" + auth = ApiKeyAuthSync(api_token) + result = auth.refresh() + + assert result is None + + def test_multiple_calls_return_same_token(self): + """Test that multiple get() calls return the same token.""" + api_token = "test_api_token_stable" + auth = ApiKeyAuthSync(api_token) + + result1 = auth.get() + result2 = auth.get() + + assert result1 == f"Bearer {api_token}" + assert result2 == f"Bearer {api_token}" + + +class TestApiTokenAuthAsync: + """Tests for asynchronous API token authentication.""" + + @pytest.mark.asyncio + async def test_get_returns_token(self): + """Test that get() returns the API token.""" + api_token = "test_api_token_async_12345" + auth = ApiKeyAuthAsync(api_token) + result = await auth.get() + + assert result == f"Bearer {api_token}" + + @pytest.mark.asyncio + async def test_refresh_returns_none(self): + """Test that refresh() returns None (API keys don't have refresh mechanism).""" + api_token = "test_api_token_async_67890" + auth = ApiKeyAuthAsync(api_token) + result = await auth.refresh() + + assert result is None + + @pytest.mark.asyncio + async def test_multiple_calls_return_same_token(self): + """Test that multiple get() calls return the same token.""" + api_token = "test_api_token_async_stable" + auth = ApiKeyAuthAsync(api_token) + + result1 = await auth.get() + result2 = await auth.get() + + assert result1 == f"Bearer {api_token}" + assert result2 == f"Bearer {api_token}" diff --git a/tests/_internal/test_http.py b/tests/_internal/test_http.py new file mode 100644 index 0000000..e8e63e5 --- /dev/null +++ b/tests/_internal/test_http.py @@ -0,0 +1,250 @@ +# Copyright (c) 2023 EPAM Systems +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License +import http.server +import socketserver +import threading +from unittest import mock + +# noinspection PyProtectedMember +from reportportal_client._internal.http import ClientSession +from reportportal_client._internal.services.auth import ApiKeyAuthSync + + +class OkHttpHandler(http.server.BaseHTTPRequestHandler): + def do_GET(self): + self.send_response(200) + self.send_header("Content-Type", "application/json") + self.end_headers() + self.wfile.write("{}\n\n".encode("utf-8")) + self.wfile.flush() + + +class UnauthorizedHttpHandler(http.server.BaseHTTPRequestHandler): + def do_GET(self): + auth_header = self.headers.get("Authorization") + if auth_header == "Bearer test_api_key": + self.send_response(200) + self.send_header("Content-Type", "application/json") + self.end_headers() + self.wfile.write("{}\n\n".encode("utf-8")) + else: + self.send_response(401, "Unauthorized") + self.end_headers() + self.wfile.write("Unauthorized\n\n".encode("utf-8")) + self.wfile.flush() + + +SERVER_PORT = 10000 +SERVER_ADDRESS = ("", SERVER_PORT) +SERVER_CLASS = socketserver.TCPServer + + +# Allow socket reuse to avoid "Address already in use" errors +class ReuseAddrTCPServer(socketserver.TCPServer): + allow_reuse_address = True + + +def get_http_server( + *, + server_handler, + server_address=SERVER_ADDRESS, +): + httpd = ReuseAddrTCPServer(server_address, server_handler) + thread = threading.Thread(target=httpd.serve_forever, daemon=True) + thread.start() + return httpd + + +def test_no_auth_request(): + """Test that requests work without authentication.""" + port = 10000 + session = ClientSession() + + with get_http_server(server_handler=OkHttpHandler, server_address=("", port)): + with session: + result = session.get(f"http://localhost:{port}/") + assert result.ok + assert result.status_code == 200 + + +def test_auth_header_added_to_request(): + """Test that auth header is added to requests when auth is configured.""" + port = 10001 + auth = ApiKeyAuthSync("test_api_key") + session = ClientSession(auth=auth) + + with get_http_server(server_handler=UnauthorizedHttpHandler, server_address=("", port)): + with session: + result = session.get(f"http://localhost:{port}/") + assert result.ok + assert result.status_code == 200 + + +def test_auth_refresh_on_401(): + """Test that 401 response triggers auth refresh.""" + port = 10002 + + # Create a mock auth that fails first, then succeeds + auth = mock.Mock() + auth.get = mock.Mock(side_effect=["Bearer invalid_token", "Bearer test_api_key"]) + auth.refresh = mock.Mock(return_value="Bearer test_api_key") + + session = ClientSession(auth=auth) + + with get_http_server(server_handler=UnauthorizedHttpHandler, server_address=("", port)): + with session: + result = session.get(f"http://localhost:{port}/") + # First call to get() returns invalid token, which causes 401 + # Then refresh() is called and returns valid token + # Request is retried with valid token and succeeds + assert result.ok + assert result.status_code == 200 + assert auth.get.call_count == 1 + assert auth.refresh.call_count == 1 + + +def test_auth_refresh_only_once(): + """Test that auth refresh is only performed once per request.""" + port = 10003 + + # Create a mock auth that always fails + auth = mock.Mock() + auth.get = mock.Mock(return_value="Bearer invalid_token") + auth.refresh = mock.Mock(return_value="Bearer still_invalid_token") + + session = ClientSession(auth=auth) + + with get_http_server(server_handler=UnauthorizedHttpHandler, server_address=("", port)): + with session: + result = session.get(f"http://localhost:{port}/") + # Auth refresh should only be attempted once + assert not result.ok + assert result.status_code == 401 + assert auth.get.call_count == 1 + assert auth.refresh.call_count == 1 + + +def test_post_request_with_auth(): + """Test that POST requests work with authentication.""" + port = 10004 + auth = ApiKeyAuthSync("test_api_key") + session = ClientSession(auth=auth) + + class PostHandler(http.server.BaseHTTPRequestHandler): + def do_POST(self): + # Read the request body to avoid connection reset + content_length = int(self.headers.get("Content-Length", 0)) + if content_length: + self.rfile.read(content_length) + + auth_header = self.headers.get("Authorization") + if auth_header == "Bearer test_api_key": + self.send_response(200) + self.send_header("Content-Type", "application/json") + self.end_headers() + self.wfile.write("{}\n\n".encode("utf-8")) + else: + self.send_response(401, "Unauthorized") + self.end_headers() + self.wfile.write("Unauthorized\n\n".encode("utf-8")) + self.wfile.flush() + + with get_http_server(server_handler=PostHandler, server_address=("", port)): + with session: + result = session.post(f"http://localhost:{port}/", data={"test": "data"}) + assert result.ok + assert result.status_code == 200 + + +def test_put_request_with_auth(): + """Test that PUT requests work with authentication.""" + port = 10005 + auth = ApiKeyAuthSync("test_api_key") + session = ClientSession(auth=auth) + + class PutHandler(http.server.BaseHTTPRequestHandler): + def do_PUT(self): + # Read the request body to avoid connection reset + content_length = int(self.headers.get("Content-Length", 0)) + if content_length: + self.rfile.read(content_length) + + auth_header = self.headers.get("Authorization") + if auth_header == "Bearer test_api_key": + self.send_response(200) + self.send_header("Content-Type", "application/json") + self.end_headers() + self.wfile.write("{}\n\n".encode("utf-8")) + else: + self.send_response(401, "Unauthorized") + self.end_headers() + self.wfile.write("Unauthorized\n\n".encode("utf-8")) + self.wfile.flush() + + with get_http_server(server_handler=PutHandler, server_address=("", port)): + with session: + result = session.put(f"http://localhost:{port}/", data={"test": "data"}) + assert result.ok + assert result.status_code == 200 + + +def test_403_triggers_auth_refresh(): + """Test that 403 response also triggers auth refresh.""" + port = 10006 + + class ForbiddenHttpHandler(http.server.BaseHTTPRequestHandler): + def do_GET(self): + auth_header = self.headers.get("Authorization") + if auth_header == "Bearer test_api_key": + self.send_response(200) + self.send_header("Content-Type", "application/json") + self.end_headers() + self.wfile.write("{}\n\n".encode("utf-8")) + else: + self.send_response(403, "Forbidden") + self.end_headers() + self.wfile.write("Forbidden\n\n".encode("utf-8")) + self.wfile.flush() + + # Create a mock auth that fails first, then succeeds + auth = mock.Mock() + auth.get = mock.Mock(side_effect=["Bearer invalid_token", "Bearer test_api_key"]) + auth.refresh = mock.Mock(return_value="Bearer test_api_key") + + session = ClientSession(auth=auth) + + with get_http_server(server_handler=ForbiddenHttpHandler, server_address=("", port)): + with session: + result = session.get(f"http://localhost:{port}/") + assert result.ok + assert result.status_code == 200 + assert auth.get.call_count == 1 + assert auth.refresh.call_count == 1 + + +def test_mount_adapter(): + """Test that mount method allows mounting adapters.""" + import requests.adapters + + session = ClientSession() + adapter = requests.adapters.HTTPAdapter(max_retries=3) + + # Test that mount method works without error + session.mount("http://", adapter) + session.mount("https://", adapter) + + # Verify the adapter was mounted by checking internal session + assert session._client.get_adapter("http://example.com") == adapter + assert session._client.get_adapter("https://example.com") == adapter + + session.close() diff --git a/tests/aio/test_aio_client.py b/tests/aio/test_aio_client.py index 3bb4a7f..888c02e 100644 --- a/tests/aio/test_aio_client.py +++ b/tests/aio/test_aio_client.py @@ -27,7 +27,7 @@ from reportportal_client import OutputType # noinspection PyProtectedMember -from reportportal_client._internal.aio.http import DEFAULT_RETRY_NUMBER, RetryingClientSession +from reportportal_client._internal.aio.http import DEFAULT_RETRY_NUMBER, ClientSession, RetryingClientSession # noinspection PyProtectedMember from reportportal_client._internal.static.defines import NOT_SET @@ -55,7 +55,7 @@ def test_client_pickling(): @pytest.mark.parametrize( - "retry_num, expected_class, expected_param", + "retry_num, expected_wrapped_class, expected_param", [ (1, RetryingClientSession, 1), (0, aiohttp.ClientSession, NOT_SET), @@ -65,12 +65,17 @@ def test_client_pickling(): ], ) @pytest.mark.asyncio -async def test_retries_param(retry_num, expected_class, expected_param): +async def test_retries_param(retry_num, expected_wrapped_class, expected_param): client = Client(ENDPOINT, PROJECT, api_key=API_KEY, retries=retry_num) session = await client.session() - assert isinstance(session, expected_class) + # Session is now a ClientSession wrapper + assert isinstance(session, ClientSession) + # Check the wrapped session type + # noinspection PyProtectedMember + assert isinstance(session._client, expected_wrapped_class) if expected_param is not NOT_SET: - assert getattr(session, "_RetryingClientSession__retry_number") == expected_param + # noinspection PyProtectedMember + assert getattr(session._client, "_RetryingClientSession__retry_number") == expected_param @pytest.mark.parametrize( @@ -86,13 +91,12 @@ async def test_timeout_param(mocked_session, timeout_param, expected_connect_par assert len(mocked_session.call_args_list) == 1 args, kwargs = mocked_session.call_args_list[0] assert len(args) == 1 and args[0] == ENDPOINT - expected_kwargs_keys = ["headers", "connector"] + expected_kwargs_keys = ["connector"] if timeout_param: expected_kwargs_keys.append("timeout") for key in expected_kwargs_keys: assert key in kwargs assert len(expected_kwargs_keys) == len(kwargs) - assert kwargs["headers"] == {"Authorization": f"Bearer {API_KEY}"} assert kwargs["connector"] is not None if timeout_param: assert kwargs["timeout"] is not None @@ -133,6 +137,149 @@ def test_clone(): ) +@mock.patch("reportportal_client.aio.client.warnings.warn") +def test_deprecated_token_argument(warn): + """Test that deprecated token argument works and issues a warning.""" + api_key = "api_key" + client = Client(endpoint="http://endpoint", project="project", token=api_key) + + assert warn.call_count == 1 + assert client.api_key == api_key + + +@mock.patch("reportportal_client.aio.client.warnings.warn") +def test_api_key_argument(warn): + """Test that normal api_key argument works without warning.""" + api_key = "api_key" + client = Client(endpoint="http://endpoint", project="project", api_key=api_key) + + assert warn.call_count == 0 + assert client.api_key == api_key + + +def test_empty_api_key_argument(): + """Test that empty api_key raises ValueError.""" + api_key = "" + with pytest.raises(ValueError) as exc_info: + Client(endpoint="http://endpoint", project="project", api_key=api_key) + + assert "Authentication credentials are required" in str(exc_info.value) + + +def test_oauth_authentication_parameters(): + """Test that OAuth 2.0 authentication parameters work correctly.""" + client = Client( + endpoint="http://endpoint", + project="project", + oauth_uri="https://example.com/oauth/token", + oauth_username="test_user", + oauth_password="test_password", + oauth_client_id="test_client_id", + oauth_client_secret="test_client_secret", + oauth_scope="read write", + ) + + assert client is not None + assert client.oauth_uri == "https://example.com/oauth/token" + assert client.oauth_username == "test_user" + assert client.oauth_password == "test_password" + assert client.oauth_client_id == "test_client_id" + assert client.oauth_client_secret == "test_client_secret" + assert client.oauth_scope == "read write" + assert client.api_key is None + + +def test_oauth_authentication_without_optional_parameters(): + """Test OAuth authentication with only required parameters.""" + client = Client( + endpoint="http://endpoint", + project="project", + oauth_uri="https://example.com/oauth/token", + oauth_username="test_user", + oauth_password="test_password", + oauth_client_id="test_client_id", + ) + + assert client is not None + assert client.oauth_uri == "https://example.com/oauth/token" + assert client.oauth_username == "test_user" + assert client.oauth_password == "test_password" + assert client.oauth_client_id == "test_client_id" + assert client.oauth_client_secret is None + assert client.oauth_scope is None + assert client.api_key is None + + +def test_no_authentication_parameters(): + """Test that missing authentication parameters raises ValueError.""" + with pytest.raises(ValueError) as exc_info: + Client(endpoint="http://endpoint", project="project") + + assert "Authentication credentials are required" in str(exc_info.value) + assert "OAuth 2.0 parameters" in str(exc_info.value) + assert "api_key parameter" in str(exc_info.value) + + +def test_partial_oauth_parameters(): + """Test that missing authentication parameters raises ValueError.""" + with pytest.raises(ValueError) as exc_info: + Client( + endpoint="http://endpoint", + project="project", + oauth_uri="https://example.com/oauth/token", + oauth_username="test_user", + oauth_password="test_password", + ) + + assert "Authentication credentials are required" in str(exc_info.value) + assert "OAuth 2.0 parameters" in str(exc_info.value) + assert "api_key parameter" in str(exc_info.value) + + +def test_clone_with_oauth(): + """Test cloning a client with OAuth authentication.""" + args = ["http://endpoint", "project"] + kwargs = { + "oauth_uri": "https://example.com/oauth/token", + "oauth_username": "test_user", + "oauth_password": "test_password", + "oauth_client_id": "test_client_id", + "oauth_client_secret": "test_secret", + "oauth_scope": "read write", + "is_skipped_an_issue": False, + "verify_ssl": False, + "retries": 5, + "max_pool_size": 30, + "http_timeout": (30, 30), + "keepalive_timeout": 25, + "mode": "DEBUG", + "launch_uuid_print": True, + "print_output": OutputType.STDERR, + } + client = Client(*args, **kwargs) + cloned = client.clone() + + assert cloned is not None and client is not cloned + assert cloned.endpoint == args[0] and cloned.project == args[1] + assert ( + cloned.oauth_uri == kwargs["oauth_uri"] + and cloned.oauth_username == kwargs["oauth_username"] + and cloned.oauth_password == kwargs["oauth_password"] + and cloned.oauth_client_id == kwargs["oauth_client_id"] + and cloned.oauth_client_secret == kwargs["oauth_client_secret"] + and cloned.oauth_scope == kwargs["oauth_scope"] + and cloned.is_skipped_an_issue == kwargs["is_skipped_an_issue"] + and cloned.verify_ssl == kwargs["verify_ssl"] + and cloned.retries == kwargs["retries"] + and cloned.max_pool_size == kwargs["max_pool_size"] + and cloned.http_timeout == kwargs["http_timeout"] + and cloned.keepalive_timeout == kwargs["keepalive_timeout"] + and cloned.mode == kwargs["mode"] + and cloned.launch_uuid_print == kwargs["launch_uuid_print"] + and cloned.print_output == kwargs["print_output"] + ) + + LAUNCH_ID = 333 EXPECTED_DEFAULT_URL = f"http://endpoint/ui/#project/launches/all/{LAUNCH_ID}" EXPECTED_DEBUG_URL = f"http://endpoint/ui/#project/userdebug/all/{LAUNCH_ID}" @@ -770,3 +917,85 @@ async def test_attribute_truncation(aio_client: Client, method, mock_method, cal assert "attributes" in kwargs["json"] assert kwargs["json"]["attributes"] assert len(kwargs["json"]["attributes"][0]["value"]) == 128 + + +@pytest.mark.asyncio +async def test_api_key_authorization_header(): + """Test that API key authentication sets Authorization header correctly.""" + api_key = "test_api_key_12345" + client = Client(endpoint=ENDPOINT, project=PROJECT, api_key=api_key) + + # Get the session (which is ClientSession wrapper) + session = await client.session() + + # Mock the underlying aiohttp.ClientSession within ClientSession + # noinspection PyProtectedMember + underlying_session_mock = mock.AsyncMock() + + # Mock response with status attribute + response_mock = mock.Mock() + response_mock.status = 200 + response_mock.json = mock.AsyncMock(return_value=RETURN_GET_JSON) + underlying_session_mock.get.return_value = response_mock + + # noinspection PyProtectedMember + session._client = underlying_session_mock + client._skip_analytics = "1" + + # Make a request + await client.get_project_settings() + + # Verify the underlying session.get was called + underlying_session_mock.get.assert_called_once() + call_kwargs = underlying_session_mock.get.call_args_list[0][1] + + # Verify Authorization header is set correctly + assert "headers" in call_kwargs + assert "Authorization" in call_kwargs["headers"] + assert call_kwargs["headers"]["Authorization"] == f"Bearer {api_key}" + + +@pytest.mark.asyncio +async def test_oauth_authorization_header(): + """Test that OAuth authentication sets Authorization header correctly.""" + client = Client( + endpoint=ENDPOINT, + project=PROJECT, + oauth_uri="https://example.com/oauth/token", + oauth_username="test_user", + oauth_password="test_password", + oauth_client_id="test_client_id", + ) + + # Get the session (which is ClientSession wrapper) + session = await client.session() + + # Mock the underlying aiohttp.ClientSession within ClientSession + # noinspection PyProtectedMember + underlying_session_mock = mock.AsyncMock() + + # Mock response with status attribute + response_mock = mock.Mock() + response_mock.status = 200 + response_mock.json = mock.AsyncMock(return_value=RETURN_GET_JSON) + underlying_session_mock.get.return_value = response_mock + + # noinspection PyProtectedMember + session._client = underlying_session_mock + client._skip_analytics = "1" + + # Mock the Auth.get() method to return a test token + test_token = "test_oauth_token_xyz" + client.auth._access_token = test_token + with mock.patch.object(client.auth, "_is_token_expired", return_value=False): + # Make a request + await client.get_project_settings() + + # Verify the underlying session.get was called + underlying_session_mock.get.assert_called_once() + call_kwargs = underlying_session_mock.get.call_args_list[0][1] + + # Verify Authorization header is set correctly + assert "headers" in call_kwargs + assert "Authorization" in call_kwargs["headers"] + assert call_kwargs["headers"]["Authorization"] == f"Bearer {test_token}" diff --git a/tests/conftest.py b/tests/conftest.py index 5000882..9d7c6af 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -23,6 +23,9 @@ class DummyResponse: + def __init__(self): + self.status_code = 200 + # noinspection PyMethodMayBeStatic def json(self): return { diff --git a/tests/test_client.py b/tests/test_client.py index c78e125..347feef 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -179,13 +179,13 @@ def test_api_key_argument(warn): assert client.api_key == api_key -@mock.patch("reportportal_client.client.warnings.warn") -def test_empty_api_key_argument(warn): +def test_empty_api_key_argument(): + """Test that empty api_key raises ValueError.""" api_key = "" - client = RPClient(endpoint="http://endpoint", project="project", api_key=api_key) + with pytest.raises(ValueError) as exc_info: + RPClient(endpoint="http://endpoint", project="project", api_key=api_key) - assert warn.call_count == 1 - assert client.api_key == api_key + assert "Authentication credentials are required" in str(exc_info.value) def test_launch_uuid_print(rp_client): @@ -316,3 +316,187 @@ def test_logs_flush_on_close(rp_client: RPClient): batcher.flush.assert_called_once() session.post.assert_called_once() session.close.assert_called_once() + + +def test_oauth_authentication_parameters(): + """Test that OAuth 2.0 authentication parameters work correctly.""" + client = RPClient( + endpoint="http://endpoint", + project="project", + oauth_uri="https://example.com/oauth/token", + oauth_username="test_user", + oauth_password="test_password", + oauth_client_id="test_client_id", + oauth_client_secret="test_client_secret", + oauth_scope="read write", + ) + + assert client is not None + assert client.oauth_uri == "https://example.com/oauth/token" + assert client.oauth_username == "test_user" + assert client.oauth_password == "test_password" + assert client.oauth_client_id == "test_client_id" + assert client.oauth_client_secret == "test_client_secret" + assert client.oauth_scope == "read write" + assert client.api_key is None + + +def test_oauth_authentication_without_optional_parameters(): + """Test OAuth authentication with only required parameters.""" + client = RPClient( + endpoint="http://endpoint", + project="project", + oauth_uri="https://example.com/oauth/token", + oauth_username="test_user", + oauth_password="test_password", + oauth_client_id="test_client_id", + ) + + assert client is not None + assert client.oauth_uri == "https://example.com/oauth/token" + assert client.oauth_username == "test_user" + assert client.oauth_password == "test_password" + assert client.oauth_client_id == "test_client_id" + assert client.oauth_client_secret is None + assert client.oauth_scope is None + assert client.api_key is None + + +def test_no_authentication_parameters(): + """Test that missing authentication parameters raises ValueError.""" + with pytest.raises(ValueError) as exc_info: + RPClient(endpoint="http://endpoint", project="project") + + assert "Authentication credentials are required" in str(exc_info.value) + assert "OAuth 2.0 parameters" in str(exc_info.value) + assert "api_key parameter" in str(exc_info.value) + + +def test_partial_oauth_parameters(): + """Test that missing authentication parameters raises ValueError.""" + with pytest.raises(ValueError) as exc_info: + RPClient( + endpoint="http://endpoint", + project="project", + oauth_uri="https://example.com/oauth/token", + oauth_username="test_user", + oauth_password="test_password", + ) + + assert "Authentication credentials are required" in str(exc_info.value) + assert "OAuth 2.0 parameters" in str(exc_info.value) + assert "api_key parameter" in str(exc_info.value) + + +def test_clone_with_oauth(): + """Test cloning a client with OAuth authentication.""" + args = ["http://endpoint", "project"] + kwargs = { + "oauth_uri": "https://example.com/oauth/token", + "oauth_username": "test_user", + "oauth_password": "test_password", + "oauth_client_id": "test_client_id", + "oauth_client_secret": "test_secret", + "oauth_scope": "read write", + "log_batch_size": 30, + "is_skipped_an_issue": False, + "verify_ssl": False, + "retries": 5, + "max_pool_size": 30, + "launch_id": "test-123", + "http_timeout": (30, 30), + "log_batch_payload_size": 1000000, + "mode": "DEBUG", + } + client = RPClient(*args, **kwargs) + client._add_current_item("test-321") + client._add_current_item("test-322") + cloned = client.clone() + + assert cloned is not None and client is not cloned + assert cloned.endpoint == args[0] and cloned.project == args[1] + assert ( + cloned.oauth_uri == kwargs["oauth_uri"] + and cloned.oauth_username == kwargs["oauth_username"] + and cloned.oauth_password == kwargs["oauth_password"] + and cloned.oauth_client_id == kwargs["oauth_client_id"] + and cloned.oauth_client_secret == kwargs["oauth_client_secret"] + and cloned.oauth_scope == kwargs["oauth_scope"] + and cloned.log_batch_size == kwargs["log_batch_size"] + and cloned.is_skipped_an_issue == kwargs["is_skipped_an_issue"] + and cloned.verify_ssl == kwargs["verify_ssl"] + and cloned.retries == kwargs["retries"] + and cloned.max_pool_size == kwargs["max_pool_size"] + and cloned.launch_uuid == kwargs["launch_id"] + and cloned.launch_id == kwargs["launch_id"] + and cloned.http_timeout == kwargs["http_timeout"] + and cloned.log_batch_payload_size == kwargs["log_batch_payload_size"] + and cloned.mode == kwargs["mode"] + ) + assert cloned._item_stack.qsize() == 1 and client.current_item() == cloned.current_item() + + +def test_api_key_authorization_header(): + """Test that API key authentication sets Authorization header correctly.""" + api_key = "test_api_key_12345" + client = RPClient(endpoint="http://endpoint", project="project", api_key=api_key) + + # Mock the underlying requests.Session within ClientSession + # noinspection PyProtectedMember + underlying_session_mock = mock.Mock() + underlying_session_mock.get.return_value = DummyResponse() + underlying_session_mock.post.return_value = DummyResponse() + underlying_session_mock.put.return_value = DummyResponse() + # noinspection PyProtectedMember + client.session._client = underlying_session_mock + client._skip_analytics = "1" + + # Make a request + client.get_project_settings() + + # Verify the underlying session.get was called + underlying_session_mock.get.assert_called_once() + call_kwargs = underlying_session_mock.get.call_args_list[0][1] + + # Verify Authorization header is set correctly + assert "headers" in call_kwargs + assert "Authorization" in call_kwargs["headers"] + assert call_kwargs["headers"]["Authorization"] == f"Bearer {api_key}" + + +def test_oauth_authorization_header(): + """Test that OAuth authentication sets Authorization header correctly.""" + client = RPClient( + endpoint="http://endpoint", + project="project", + oauth_uri="https://example.com/oauth/token", + oauth_username="test_user", + oauth_password="test_password", + oauth_client_id="test_client_id", + ) + + # Mock the underlying requests.Session within ClientSession + # noinspection PyProtectedMember + underlying_session_mock = mock.Mock() + underlying_session_mock.get.return_value = DummyResponse() + underlying_session_mock.post.return_value = DummyResponse() + underlying_session_mock.put.return_value = DummyResponse() + # noinspection PyProtectedMember + client.session._client = underlying_session_mock + client._skip_analytics = "1" + + # Mock the Auth.get() method to return a test token + test_token = "test_oauth_token_xyz" + client.auth._access_token = test_token + with mock.patch.object(client.auth, "_is_token_expired", return_value=False): + # Make a request + client.get_project_settings() + + # Verify the underlying session.get was called + underlying_session_mock.get.assert_called_once() + call_kwargs = underlying_session_mock.get.call_args_list[0][1] + + # Verify Authorization header is set correctly + assert "headers" in call_kwargs + assert "Authorization" in call_kwargs["headers"] + assert call_kwargs["headers"]["Authorization"] == f"Bearer {test_token}" diff --git a/tests/test_client_factory.py b/tests/test_client_factory.py index 850ce0f..3f9c354 100644 --- a/tests/test_client_factory.py +++ b/tests/test_client_factory.py @@ -27,5 +27,5 @@ ], ) def test_client_factory_types(requested_type: ClientType, expected_type): - result = create_client(requested_type, "http://endpoint", "default_personal") + result = create_client(requested_type, "http://endpoint", "default_personal", api_key="test_api_key") assert isinstance(result, expected_type)