Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Python retry mechanism #470

Merged
merged 12 commits into from
Sep 7, 2018
3 changes: 2 additions & 1 deletion stripe/api_requestor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import json
import platform
import time
import uuid

import stripe
from stripe import error, oauth_error, http_client, version, util, six
Expand Down Expand Up @@ -220,6 +221,7 @@ def request_headers(self, api_key, method):

if method == 'post':
headers['Content-Type'] = 'application/x-www-form-urlencoded'
headers.setdefault('Idempotency-Key', str(uuid.uuid4()))

if self.api_version is not None:
headers['Stripe-Version'] = self.api_version
Expand Down Expand Up @@ -271,7 +273,6 @@ def request_raw(self, method, url, params=None, supplied_headers=None):
'assistance.' % (method,))

headers = self.request_headers(my_api_key, method)

if supplied_headers is not None:
for key, value in six.iteritems(supplied_headers):
headers[key] = value
Expand Down
7 changes: 6 additions & 1 deletion stripe/error.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,12 @@ class APIError(StripeError):


class APIConnectionError(StripeError):
pass
def __init__(self, message, http_body=None, http_status=None,
json_body=None, headers=None, code=None, should_retry=False):
super(APIConnectionError, self).__init__(message, http_body,
http_status,
json_body, headers, code)
self.should_retry = should_retry
ob-stripe marked this conversation as resolved.
Show resolved Hide resolved


class StripeErrorWithParamCode(StripeError):
Expand Down
93 changes: 87 additions & 6 deletions stripe/http_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import textwrap
import warnings
import email
import time
import random

from stripe import error, util, six

Expand Down Expand Up @@ -77,6 +79,10 @@ def new_default_http_client(*args, **kwargs):


class HTTPClient(object):
MAX_RETRIES = 3
MAX_DELAY = 2
INITIAL_DELAY = 0.5
ob-stripe marked this conversation as resolved.
Show resolved Hide resolved

def __init__(self, verify_ssl_certs=True, proxy=None):
self._verify_ssl_certs = verify_ssl_certs
if proxy:
Expand All @@ -90,9 +96,62 @@ def __init__(self, verify_ssl_certs=True, proxy=None):
self._proxy = proxy.copy() if proxy else None

def request(self, method, url, headers, post_data=None):
ob-stripe marked this conversation as resolved.
Show resolved Hide resolved
num_retries = 0

while True:
try:
num_retries += 1
response = self._delegated_request(method, url, headers,
post_data)
connection_error = None
except error.APIConnectionError as e:
connection_error = e
response = None

if self._should_retry(response, connection_error, num_retries):
self._sleep(num_retries)
else:
if response is not None:
return response
else:
raise connection_error
ob-stripe marked this conversation as resolved.
Show resolved Hide resolved

def _delegated_request(self, method, url, headers, post_data=None):
raise NotImplementedError(
'HTTPClient subclasses must implement `request`')

def _should_retry(self, response, api_connection_error, num_retries):
if response is not None:
_, status_code, _ = response
should_retry = status_code == 409
else:
should_retry = api_connection_error.should_retry
ob-stripe marked this conversation as resolved.
Show resolved Hide resolved
return should_retry and num_retries < HTTPClient.MAX_RETRIES

def _sleep(self, num_retries):
time.sleep(self._sleep_time(num_retries))

def _sleep_time(self, num_retries):
# Apply exponential backoff with initial_network_retry_delay on the
# number of num_retries so far as inputs.
# Do not allow the number to exceed max_network_retry_delay.
sleep_seconds = min(
HTTPClient.INITIAL_DELAY * (2 ** (num_retries - 1)),
HTTPClient.MAX_DELAY)

sleep_seconds = self._add_jitter_time(sleep_seconds)

# But never sleep less than the base sleep seconds.
sleep_seconds = max(HTTPClient.INITIAL_DELAY, sleep_seconds)

return sleep_seconds

def _add_jitter_time(self, sleep_seconds):
# Randomize the value in [(sleep_seconds/ 2) to (sleep_seconds)]
# Also separated method here to isolate randomness for tests
ob-stripe marked this conversation as resolved.
Show resolved Hide resolved
sleep_seconds *= (0.5 * (1 + random.uniform(0, 1)))
return sleep_seconds

def close(self):
raise NotImplementedError(
'HTTPClient subclasses must implement `close`')
Expand All @@ -106,7 +165,7 @@ def __init__(self, timeout=80, session=None, **kwargs):
self._timeout = timeout
self._session = session or requests.Session()

def request(self, method, url, headers, post_data=None):
def _delegated_request(self, method, url, headers, post_data=None):
kwargs = {}
if self._verify_ssl_certs:
kwargs['verify'] = os.path.join(
Expand Down Expand Up @@ -146,11 +205,31 @@ def request(self, method, url, headers, post_data=None):
return content, status_code, result.headers

def _handle_request_error(self, e):
if isinstance(e, requests.exceptions.RequestException):

# Catch SSL error first as it belongs to ConnectionError,
# but we don't want to retry
if isinstance(e, requests.exceptions.SSLError):
ob-stripe marked this conversation as resolved.
Show resolved Hide resolved
msg = ("Could not verify Stripe's SSL certificate. Please make "
"sure that your network is not intercepting certificates. "
"If this problem persists, let us know at "
"support@stripe.com.")
err = "%s: %s" % (type(e).__name__, str(e))
should_retry = False
# Retry only timeout and connect errors; similar to urllib3 Retry
elif isinstance(e, requests.exceptions.Timeout) or \
isinstance(e, requests.exceptions.ConnectionError):
msg = ("Unexpected error communicating with Stripe. "
"If this problem persists, let us know at "
"support@stripe.com.")
err = "%s: %s" % (type(e).__name__, str(e))
should_retry = True
# Catch remaining request exceptions
elif isinstance(e, requests.exceptions.RequestException):
msg = ("Unexpected error communicating with Stripe. "
"If this problem persists, let us know at "
"support@stripe.com.")
err = "%s: %s" % (type(e).__name__, str(e))
should_retry = False
else:
msg = ("Unexpected error communicating with Stripe. "
"It looks like there's probably a configuration "
Expand All @@ -161,8 +240,10 @@ def _handle_request_error(self, e):
err += " with error message %s" % (str(e),)
else:
err += " with no error message"
should_retry = False

msg = textwrap.fill(msg) + "\n\n(Network error: %s)" % (err,)
raise error.APIConnectionError(msg)
raise error.APIConnectionError(msg, should_retry=should_retry)

def close(self):
if self._session is not None:
Expand Down Expand Up @@ -190,7 +271,7 @@ def __init__(self, verify_ssl_certs=True, proxy=None, deadline=55):
# to 55 seconds to allow for a slow Stripe
self._deadline = deadline

def request(self, method, url, headers, post_data=None):
def _delegated_request(self, method, url, headers, post_data=None):
try:
result = urlfetch.fetch(
url=url,
Expand Down Expand Up @@ -256,7 +337,7 @@ def parse_headers(self, data):
headers = email.message_from_string(raw_headers)
return dict((k.lower(), v) for k, v in six.iteritems(dict(headers)))

def request(self, method, url, headers, post_data=None):
def _delegated_request(self, method, url, headers, post_data=None):
b = util.io.BytesIO()
rheaders = util.io.BytesIO()

Expand Down Expand Up @@ -365,7 +446,7 @@ def __init__(self, verify_ssl_certs=True, proxy=None):
proxy = urllib.request.ProxyHandler(self._proxy)
self._opener = urllib.request.build_opener(proxy)

def request(self, method, url, headers, post_data=None):
def _delegated_request(self, method, url, headers, post_data=None):
if six.PY3 and isinstance(post_data, six.string_types):
post_data = post_data.encode('utf-8')

Expand Down
56 changes: 52 additions & 4 deletions tests/test_api_requestor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import datetime
import json
import tempfile
import uuid

import pytest

Expand Down Expand Up @@ -35,28 +36,32 @@ class APIHeaderMatcher(object):
'User-Agent',
'X-Stripe-Client-User-Agent',
]
METHOD_EXTRA_KEYS = {"post": ["Content-Type"]}
METHOD_EXTRA_KEYS = {"post": ["Content-Type", "Idempotency-Key"]}

def __init__(self, api_key=None, extra={}, request_method=None,
user_agent=None, app_info=None):
user_agent=None, app_info=None, idempotency_key=None):
self.request_method = request_method
self.api_key = api_key or stripe.api_key
self.extra = extra
self.user_agent = user_agent
self.app_info = app_info
self.idempotency_key = idempotency_key

def __eq__(self, other):
return (self._keys_match(other) and
self._auth_match(other) and
self._user_agent_match(other) and
self._x_stripe_ua_contains_app_info(other) and
self._idempotency_key_match(other) and
self._extra_match(other))

def __repr__(self):
return ("APIHeaderMatcher(request_method=%s, api_key=%s, extra=%s, "
"user_agent=%s, app_info=%s)" %
"user_agent=%s, app_info=%s, idempotency_key=%s)" %
(repr(self.request_method), repr(self.api_key),
repr(self.extra), repr(self.user_agent), repr(self.app_info)))
repr(self.extra), repr(self.user_agent), repr(self.app_info),
repr(self.idempotency_key))
)

def _keys_match(self, other):
expected_keys = list(set(self.EXP_KEYS + list(self.extra.keys())))
Expand All @@ -74,6 +79,11 @@ def _user_agent_match(self, other):

return True

def _idempotency_key_match(self, other):
if self.idempotency_key is not None:
return other['Idempotency-Key'] == self.idempotency_key
return True

def _x_stripe_ua_contains_app_info(self, other):
if self.app_info:
ua = json.loads(other['X-Stripe-Client-User-Agent'])
Expand Down Expand Up @@ -129,6 +139,19 @@ def __repr__(self):
return ("UrlMatcher(exp_parts=%s)" % (repr(self.exp_parts)))


class AnyUUID4Matcher(object):

def __eq__(self, other):
try:
uuid.UUID(other, version=4)
except ValueError:
return False
return True

def __repr__(self):
return "AnyUUID4Matcher()"


class TestAPIRequestor(object):
ENCODE_INPUTS = {
'dict': {
Expand Down Expand Up @@ -417,6 +440,31 @@ def test_uses_app_info(self, requestor, mock_response, check_call):
finally:
stripe.app_info = old

def test_uses_given_idempotency_key(self, requestor, mock_response,
check_call):
mock_response('{}', 200)
meth = 'post'
requestor.request(meth, self.valid_path, {},
{'Idempotency-Key': '123abc'})

header_matcher = APIHeaderMatcher(
request_method=meth,
idempotency_key='123abc'
)
check_call(meth, headers=header_matcher, post_data='')

def test_uuid4_idempotency_key_when_not_given(self, requestor,
mock_response, check_call):
mock_response('{}', 200)
meth = 'post'
requestor.request(meth, self.valid_path, {})

header_matcher = APIHeaderMatcher(
request_method=meth,
idempotency_key=AnyUUID4Matcher()
)
check_call(meth, headers=header_matcher, post_data='')

def test_fails_without_api_key(self, requestor):
stripe.api_key = None

Expand Down
9 changes: 9 additions & 0 deletions tests/test_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,12 @@ def test_repr(self):
assert repr(err) == \
"CardError(message='öre', param='cparam', code='ccode', " \
"http_status=403, request_id='123')"


class TestApiConnectionError(object):
def test_default_no_retry(self):
err = error.APIConnectionError('msg')
assert err.should_retry is False

err = error.APIConnectionError('msg', should_retry=True)
assert err.should_retry
Loading