diff --git a/cloudpub/ms_azure/session.py b/cloudpub/ms_azure/session.py index 7712f80..def77d2 100644 --- a/cloudpub/ms_azure/session.py +++ b/cloudpub/ms_azure/session.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: GPL-3.0-or-later import logging +import os from datetime import datetime from typing import Any, Dict, Optional @@ -10,6 +11,8 @@ log = logging.getLogger(__name__) +AZURE_SESSION_TIMEOUT: float = float(os.environ.get("AZURE_SESSION_TIMEOUT", 5.0)) + class AccessToken: """Represent the Microsoft API Authorization token.""" @@ -131,7 +134,7 @@ def _login(self) -> AccessToken: "grant_type": "client_credentials", } - resp = self.session.post(url, headers=headers, data=data, timeout=30) + resp = self.session.post(url, headers=headers, data=data, timeout=AZURE_SESSION_TIMEOUT) resp.raise_for_status() return AccessToken(resp.json()) @@ -159,7 +162,10 @@ def _request( log.debug("Sending a %s request to %s", method, path) formatted_url = self._prefix_url.format(**self.auth_keys) url = join_url(formatted_url, path) - return self.session.request(method, url=url, params=params, headers=headers, **kwargs) + timeout = kwargs.pop("timeout", AZURE_SESSION_TIMEOUT) + return self.session.request( + method, url=url, params=params, headers=headers, timeout=timeout, **kwargs + ) def get(self, path: str, **kwargs: Any) -> requests.Response: """Execute an API GET request.""" diff --git a/tests/ms_azure/test_session.py b/tests/ms_azure/test_session.py index 996f6c7..bd5a489 100644 --- a/tests/ms_azure/test_session.py +++ b/tests/ms_azure/test_session.py @@ -1,3 +1,5 @@ +import importlib +import os from datetime import datetime, timedelta from typing import Any, Dict from unittest import mock @@ -5,7 +7,7 @@ import pytest from httmock import response -from cloudpub.ms_azure.session import AccessToken, PartnerPortalSession +from cloudpub.ms_azure.session import AZURE_SESSION_TIMEOUT, AccessToken, PartnerPortalSession from cloudpub.utils import join_url @@ -72,7 +74,7 @@ def test_login( session_mock.return_value.request.assert_called_once() session_mock.return_value.post.assert_called_once_with( - login_url, headers=login_header, data=login_data, timeout=30 + login_url, headers=login_header, data=login_data, timeout=AZURE_SESSION_TIMEOUT ) @pytest.mark.parametrize( @@ -110,10 +112,124 @@ def test_request( if json: getattr(session, method)(path, json) mock_session.return_value.request.assert_called_once_with( - method, url=url, params=put_param, headers=put_headers, json={"foo": "bar"} + method, + url=url, + params=put_param, + headers=put_headers, + json={"foo": "bar"}, + timeout=AZURE_SESSION_TIMEOUT, ) else: getattr(session, method)(path) mock_session.return_value.request.assert_called_once_with( - method, url=url, params=put_param, headers=put_headers + method, + url=url, + params=put_param, + headers=put_headers, + timeout=AZURE_SESSION_TIMEOUT, + ) + + @mock.patch("cloudpub.ms_azure.session.requests.Session") + def test_azure_session_timeout_from_env( + self, + session_mock: mock.MagicMock, + auth_dict: Dict[str, str], + token: Dict[str, str], + ) -> None: + """AZURE_SESSION_TIMEOUT from the environment is used for login and API calls.""" + import cloudpub.ms_azure.session as session_mod + + env_timeout = "37.5" + try: + with mock.patch.dict(os.environ, {"AZURE_SESSION_TIMEOUT": env_timeout}, clear=False): + importlib.reload(session_mod) + assert session_mod.AZURE_SESSION_TIMEOUT == float(env_timeout) + + session_mock.return_value.post.return_value = response(200, token) + + tenant = auth_dict['AZURE_TENANT_ID'] + login_url = f"https://login.microsoftonline.com/{tenant}/oauth2/token" + login_header = {"Accept": "application/json"} + login_data = { + "resource": "https://graph.microsoft.com", + "client_id": auth_dict["AZURE_CLIENT_ID"], + "client_secret": auth_dict["AZURE_API_SECRET"], + "grant_type": "client_credentials", + } + + session = session_mod.PartnerPortalSession.make_graph_api_session( + auth_dict, schema_version=auth_dict['AZURE_SCHEMA_VERSION'] + ) + session.get("/foo") + + session_mock.return_value.post.assert_called_once_with( + login_url, + headers=login_header, + data=login_data, + timeout=float(env_timeout), + ) + session_mock.return_value.request.assert_called_once_with( + "get", + url=join_url("https://graph.microsoft.com/rp/product-ingestion", "foo"), + params={'$version': auth_dict['AZURE_SCHEMA_VERSION']}, + headers={ + 'Accept': 'application/json', + 'Authorization': f'Bearer {token["access_token"]}', + }, + timeout=float(env_timeout), + ) + finally: + importlib.reload(session_mod) + + @pytest.mark.parametrize( + 'method,path,body', + [ + ('get', 'foo', None), + ('post', "foo", {"foo": "bar"}), + ('put', "foo", {"foo": "bar"}), + ], + ) + @mock.patch("cloudpub.ms_azure.session.requests.Session") + def test_request_respects_explicit_timeout( + self, + mock_session: mock.MagicMock, + method: str, + path: str, + body: Dict[str, Any] | None, + auth_dict: Dict[str, str], + token: Dict[str, str], + ) -> None: + """Explicit timeout= on get/post/put overrides AZURE_SESSION_TIMEOUT.""" + mock_session.return_value.post.return_value = response(200, token) + + url = join_url("https://graph.microsoft.com/rp/product-ingestion", path) + put_headers = { + 'Accept': 'application/json', + 'Authorization': f'Bearer {token["access_token"]}', + } + put_param = {'$version': auth_dict['AZURE_SCHEMA_VERSION']} + explicit_timeout = 88.0 + + session = PartnerPortalSession.make_graph_api_session( + auth_dict, schema_version=auth_dict['AZURE_SCHEMA_VERSION'] + ) + + if body is not None: + getattr(session, method)(path, body, timeout=explicit_timeout) + mock_session.return_value.request.assert_called_once_with( + method, + url=url, + params=put_param, + headers=put_headers, + json={"foo": "bar"}, + timeout=explicit_timeout, + ) + else: + getattr(session, method)(path, timeout=explicit_timeout) + mock_session.return_value.request.assert_called_once_with( + method, + url=url, + params=put_param, + headers=put_headers, + timeout=explicit_timeout, )