diff --git a/sypht/client.py b/sypht/client.py index 849fa57..a709bf0 100644 --- a/sypht/client.py +++ b/sypht/client.py @@ -6,8 +6,10 @@ from urllib.parse import quote_plus, urlencode, urljoin import requests +from requests.adapters import HTTPAdapter +from urllib3.util import Retry -from .util import fetch_all_pages +from sypht.util import fetch_all_pages SYPHT_API_BASE_ENDPOINT = "https://api.sypht.com" SYPHT_AUTH_ENDPOINT = "https://auth.sypht.com/oauth2/token" @@ -73,9 +75,27 @@ def __init__( self._company_id = None self._authenticate_client() + @property + def _retry_adapter(self): + retry_strategy = Retry( + total=None, # set connect, read, redirect, status, other instead + connect=3, + read=3, + redirect=0, + status=3, + status_forcelist=[429, 502, 503, 504], + other=0, # catch-all for other errors + allowed_methods=["GET"], + respect_retry_after_header=False, + backoff_factor=0.5, # 0.0, 0.5, 1.0, 2.0, 4.0 + ) + return HTTPAdapter(max_retries=retry_strategy) + @property def _create_session(self): - return requests.Session() + session = requests.Session() + session.mount(self.base_endpoint, self._retry_adapter) + return session def _authenticate_v2(self, endpoint, client_id, client_secret, audience): basic_auth_slug = b64encode( diff --git a/tests/tests_client.py b/tests/tests_client.py index c01c9c2..a619448 100644 --- a/tests/tests_client.py +++ b/tests/tests_client.py @@ -1,7 +1,8 @@ -import os import unittest import warnings from datetime import datetime, timedelta +from http.client import HTTPMessage +from unittest.mock import ANY, Mock, call, patch from uuid import UUID, uuid4 from sypht.client import SyphtClient @@ -93,5 +94,66 @@ def test_reauthentication(self): self.assertFalse(self.sypht_client._is_token_expired()) +class RetryTest(unittest.TestCase): + """Test the global retry logic works as we expect it to.""" + + @patch.object(SyphtClient, "_authenticate_v2", return_value=("access_token", 100)) + @patch.object(SyphtClient, "_authenticate_v1", return_value=("access)token2", 100)) + @patch("urllib3.connectionpool.HTTPConnectionPool._get_conn") + def test_it_should_eventually_fail_for_50x( + self, getconn_mock: Mock, auth_v1: Mock, auth_v2: Mock + ): + """See https://stackoverflow.com/questions/66497627/how-to-test-retry-attempts-in-python-using-the-request-library .""" + + # arrange + getconn_mock.return_value.getresponse.side_effect = [ + Mock(status=502, msg=HTTPMessage()), + # Retries start from here... + # There should be n for where Retry(status=n). + Mock(status=502, msg=HTTPMessage()), + Mock(status=503, msg=HTTPMessage()), + Mock(status=504, msg=HTTPMessage()), + ] + sypht_client = SyphtClient() + + # act / assert + with self.assertRaisesRegex(Exception, "Max retries exceeded") as e: + sypht_client.get_annotations( + from_date=datetime( + year=2021, month=1, day=1, hour=0, minute=0, second=0 + ).strftime("%Y-%m-%d"), + to_date=datetime( + year=2021, month=1, day=1, hour=0, minute=0, second=0 + ).strftime("%Y-%m-%d"), + ) + assert getconn_mock.return_value.request.mock_calls == [ + call( + "GET", + "/app/annotations?offset=0&fromDate=2021-01-01&toDate=2021-01-01", + body=None, + headers=ANY, + ), + # Retries start here... + call( + "GET", + "/app/annotations?offset=0&fromDate=2021-01-01&toDate=2021-01-01", + body=None, + headers=ANY, + ), + call( + "GET", + "/app/annotations?offset=0&fromDate=2021-01-01&toDate=2021-01-01", + body=None, + headers=ANY, + ), + call( + "GET", + "/app/annotations?offset=0&fromDate=2021-01-01&toDate=2021-01-01", + body=None, + headers=ANY, + ), + ] + + if __name__ == "__main__": unittest.main()