Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions cloudpub/ms_azure/session.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# SPDX-License-Identifier: GPL-3.0-or-later
import logging
import os
from datetime import datetime
from typing import Any, Dict, Optional

Expand All @@ -10,6 +11,8 @@

log = logging.getLogger(__name__)

AZURE_SESSION_TIMEOUT: float = float(os.environ.get("AZURE_SESSION_TIMEOUT", 5.0))
Comment thread
JAVGan marked this conversation as resolved.


class AccessToken:
"""Represent the Microsoft API Authorization token."""
Expand Down Expand Up @@ -131,7 +134,7 @@ def _login(self) -> AccessToken:
"grant_type": "client_credentials",
}

resp = self.session.post(url, headers=headers, data=data, timeout=30)
resp = self.session.post(url, headers=headers, data=data, timeout=AZURE_SESSION_TIMEOUT)
Comment thread
lslebodn marked this conversation as resolved.
resp.raise_for_status()
return AccessToken(resp.json())

Expand Down Expand Up @@ -159,7 +162,10 @@ def _request(
log.debug("Sending a %s request to %s", method, path)
formatted_url = self._prefix_url.format(**self.auth_keys)
url = join_url(formatted_url, path)
return self.session.request(method, url=url, params=params, headers=headers, **kwargs)
timeout = kwargs.pop("timeout", AZURE_SESSION_TIMEOUT)
return self.session.request(
method, url=url, params=params, headers=headers, timeout=timeout, **kwargs
)

def get(self, path: str, **kwargs: Any) -> requests.Response:
"""Execute an API GET request."""
Expand Down
124 changes: 120 additions & 4 deletions tests/ms_azure/test_session.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import importlib
import os
from datetime import datetime, timedelta
from typing import Any, Dict
from unittest import mock

import pytest
from httmock import response

from cloudpub.ms_azure.session import AccessToken, PartnerPortalSession
from cloudpub.ms_azure.session import AZURE_SESSION_TIMEOUT, AccessToken, PartnerPortalSession
from cloudpub.utils import join_url


Expand Down Expand Up @@ -72,7 +74,7 @@ def test_login(

session_mock.return_value.request.assert_called_once()
session_mock.return_value.post.assert_called_once_with(
login_url, headers=login_header, data=login_data, timeout=30
login_url, headers=login_header, data=login_data, timeout=AZURE_SESSION_TIMEOUT
Comment thread
JAVGan marked this conversation as resolved.
)

@pytest.mark.parametrize(
Expand Down Expand Up @@ -110,10 +112,124 @@ def test_request(
if json:
getattr(session, method)(path, json)
mock_session.return_value.request.assert_called_once_with(
method, url=url, params=put_param, headers=put_headers, json={"foo": "bar"}
method,
url=url,
params=put_param,
headers=put_headers,
json={"foo": "bar"},
timeout=AZURE_SESSION_TIMEOUT,
Comment thread
JAVGan marked this conversation as resolved.
)
else:
getattr(session, method)(path)
mock_session.return_value.request.assert_called_once_with(
method, url=url, params=put_param, headers=put_headers
method,
url=url,
params=put_param,
headers=put_headers,
timeout=AZURE_SESSION_TIMEOUT,
)

@mock.patch("cloudpub.ms_azure.session.requests.Session")
def test_azure_session_timeout_from_env(
self,
session_mock: mock.MagicMock,
auth_dict: Dict[str, str],
token: Dict[str, str],
) -> None:
"""AZURE_SESSION_TIMEOUT from the environment is used for login and API calls."""
import cloudpub.ms_azure.session as session_mod

env_timeout = "37.5"
try:
with mock.patch.dict(os.environ, {"AZURE_SESSION_TIMEOUT": env_timeout}, clear=False):
importlib.reload(session_mod)
assert session_mod.AZURE_SESSION_TIMEOUT == float(env_timeout)

session_mock.return_value.post.return_value = response(200, token)

tenant = auth_dict['AZURE_TENANT_ID']
login_url = f"https://login.microsoftonline.com/{tenant}/oauth2/token"
login_header = {"Accept": "application/json"}
login_data = {
"resource": "https://graph.microsoft.com",
"client_id": auth_dict["AZURE_CLIENT_ID"],
"client_secret": auth_dict["AZURE_API_SECRET"],
"grant_type": "client_credentials",
}

session = session_mod.PartnerPortalSession.make_graph_api_session(
auth_dict, schema_version=auth_dict['AZURE_SCHEMA_VERSION']
)
session.get("/foo")

session_mock.return_value.post.assert_called_once_with(
login_url,
headers=login_header,
data=login_data,
timeout=float(env_timeout),
)
session_mock.return_value.request.assert_called_once_with(
"get",
url=join_url("https://graph.microsoft.com/rp/product-ingestion", "foo"),
params={'$version': auth_dict['AZURE_SCHEMA_VERSION']},
headers={
'Accept': 'application/json',
'Authorization': f'Bearer {token["access_token"]}',
},
timeout=float(env_timeout),
)
finally:
importlib.reload(session_mod)

@pytest.mark.parametrize(
'method,path,body',
[
('get', 'foo', None),
('post', "foo", {"foo": "bar"}),
('put', "foo", {"foo": "bar"}),
],
)
@mock.patch("cloudpub.ms_azure.session.requests.Session")
def test_request_respects_explicit_timeout(
self,
mock_session: mock.MagicMock,
method: str,
path: str,
body: Dict[str, Any] | None,
auth_dict: Dict[str, str],
token: Dict[str, str],
) -> None:
"""Explicit timeout= on get/post/put overrides AZURE_SESSION_TIMEOUT."""
mock_session.return_value.post.return_value = response(200, token)

url = join_url("https://graph.microsoft.com/rp/product-ingestion", path)
put_headers = {
'Accept': 'application/json',
'Authorization': f'Bearer {token["access_token"]}',
}
put_param = {'$version': auth_dict['AZURE_SCHEMA_VERSION']}
explicit_timeout = 88.0

session = PartnerPortalSession.make_graph_api_session(
auth_dict, schema_version=auth_dict['AZURE_SCHEMA_VERSION']
)

if body is not None:
getattr(session, method)(path, body, timeout=explicit_timeout)
mock_session.return_value.request.assert_called_once_with(
method,
url=url,
params=put_param,
headers=put_headers,
json={"foo": "bar"},
timeout=explicit_timeout,
)
else:
getattr(session, method)(path, timeout=explicit_timeout)
mock_session.return_value.request.assert_called_once_with(
method,
url=url,
params=put_param,
headers=put_headers,
timeout=explicit_timeout,
)
Loading