Skip to content

Commit

Permalink
Merge 6b202f5 into bd9fc7e
Browse files Browse the repository at this point in the history
  • Loading branch information
guilbep committed Mar 30, 2018
2 parents bd9fc7e + 6b202f5 commit cc9f351
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 21 deletions.
2 changes: 1 addition & 1 deletion requirements.txt
Expand Up @@ -2,7 +2,7 @@ colorama==0.3.7
formats==0.1.0a3
httpretty==0.8.3
pkginfo==1.1
requests==2.5.0
requests==2.13.0
six==1.8.0
twine==1.3.1
urllib3==1.7.1
22 changes: 11 additions & 11 deletions test_tortilla.py
Expand Up @@ -141,14 +141,14 @@ def test_cached_response(self):
self.assertEqual(self.api.cache.get(), "this should not be returned")

def test_request_delay(self):
self.api.config.delay = 0.5
self.api._config.delay = 0.5
self.api.test.get()
self.assertGreaterEqual(self._time_function(self.api.test.get), 0.5)
self.assertGreaterEqual(self._time_function(self.api.test.get, delay=0.1), 0.1)
self.assertGreaterEqual(self._time_function(self.api.test.get), 0.5)

# do not delay the rest of the tests
self.api.config.delay = 0
self.api._config.delay = 0

def test_request_methods(self):
self.assertEqual(self.api.awesome.tweet.post().message, "Success!")
Expand All @@ -163,19 +163,19 @@ def test_extensions(self):

def test_wrap_config(self):
self.api.stuff(debug=True, extension='json', cache_lifetime=5, silent=True)
self.assertTrue(self.api.stuff.config.debug)
self.assertEqual(self.api.stuff.config.extension, 'json')
self.assertEqual(self.api.stuff.config.cache_lifetime, 5)
self.assertTrue(self.api.stuff.config.silent)
self.assertTrue(self.api.stuff._config.debug)
self.assertEqual(self.api.stuff._config.extension, 'json')
self.assertEqual(self.api.stuff._config.cache_lifetime, 5)
self.assertTrue(self.api.stuff._config.silent)

self.api.stuff(debug=False, extension='xml', cache_lifetime=8, silent=False)
self.assertFalse(self.api.stuff.config.debug)
self.assertEqual(self.api.stuff.config.extension, 'xml')
self.assertEqual(self.api.stuff.config.cache_lifetime, 8)
self.assertFalse(self.api.stuff.config.silent)
self.assertFalse(self.api.stuff._config.debug)
self.assertEqual(self.api.stuff._config.extension, 'xml')
self.assertEqual(self.api.stuff._config.cache_lifetime, 8)
self.assertFalse(self.api.stuff._config.silent)

self.api.stuff('more', 'stuff', debug=True)
self.assertTrue(self.api.stuff.config.debug)
self.assertTrue(self.api.stuff._config.debug)

def test_wrap_chain(self):
self.assertIs(self.api.chained.wrap.stuff, self.api('chained').wrap('stuff'))
Expand Down
39 changes: 30 additions & 9 deletions tortilla/wrappers.py
Expand Up @@ -12,6 +12,12 @@
from .cache import CacheWrapper, DictCache
from .utils import formats, run_from_ipython, Bunch, bunchify

try:
import OpenSSL
ConnectionError = OpenSSL.SSL.SysCallError
except ImportError:
ConnectionError = requests.exceptions.ConnectionError


debug_messages = {
'request': ''.join([
Expand Down Expand Up @@ -66,13 +72,14 @@
class Client(object):
"""Wrapper around the most basic methods of the requests library."""

def __init__(self, debug=False, cache=None):
def __init__(self, debug=False, cache=None, **kwargs):
self.headers = Bunch()
self.debug = debug
self.cache = cache if cache else DictCache()
self.cache = CacheWrapper(self.cache)
self.session = requests.session()
self._last_request_time = None
self.defaults = kwargs

def _log(self, message, debug=None, **kwargs):
"""Outputs a formatted message in the console if the
Expand All @@ -89,6 +96,16 @@ def _log(self, message, debug=None, **kwargs):
if display_log:
print(message.format(**kwargs))

def send_request(self, *args, **kwargs):
"""Wrapper for session.request
Handle connection reset error even from pyopenssl
"""
try:
return self.session.request(*args, **kwargs)
except ConnectionError:
self.session.close()
return self.session.request(*args, **kwargs)

def request(self, method, url, path=(), extension=None, suffix=None,
params=None, headers=None, data=None, debug=None,
cache_lifetime=None, silent=None, ignore_cache=False,
Expand Down Expand Up @@ -186,9 +203,13 @@ def request(self, method, url, path=(), extension=None, suffix=None,
if elapsed < delay:
time.sleep(delay - elapsed)

# use default request parameters
for name, value in self.defaults.items():
kwargs.setdefault(name, value)

# execute the request
r = self.session.request(method, url, params=params,
headers=request_headers, data=data, **kwargs)
r = self.send_request(method, url, params=params,
headers=request_headers, data=data, **kwargs)
self._last_request_time = time.time()

# when not silent, raise an exception for any status code >= 400
Expand Down Expand Up @@ -247,13 +268,13 @@ class Wrap(object):
def __init__(self, part, parent=None, headers=None, params=None,
debug=None, cache_lifetime=None, silent=None,
extension=None, suffix=None, format=None, cache=None,
delay=None):
delay=None, **kwargs):
if not hasattr(part, "encode"):
part = str(part)
self._part = part[:-1] if part[-1:] == '/' else part
self._url = None
self._parent = parent or Client(debug=debug, cache=cache)
self.config = Bunch({
self._parent = parent or Client(debug=debug, cache=cache, **kwargs)
self._config = Bunch({
'headers': bunchify(headers) if headers else Bunch(),
'params': bunchify(params) if params else Bunch(),
'debug': debug,
Expand Down Expand Up @@ -298,7 +319,7 @@ def __call__(self, *parts, **options):
:param options: (optional) Arguments accepted by the
:class:`Wrap` initializer
"""
self.config.update(**options)
self._config.update(**options)

if len(parts) == 0:
return self
Expand All @@ -322,7 +343,7 @@ def __getattr__(self, part):
return self.__dict__[part]
except KeyError:
self.__dict__[part] = Wrap(part=part, parent=self,
debug=self.config.get('debug'))
debug=self._config.get('debug'))
return self.__dict__[part]

def request(self, method, *parts, **options):
Expand Down Expand Up @@ -356,7 +377,7 @@ def request(self, method, *parts, **options):
# the last part constructs the URL
options['url'] = self.url()

for key, value in six.iteritems(self.config):
for key, value in six.iteritems(self._config):
# set the defaults in the options
if value is not None:
if isinstance(value, dict):
Expand Down

0 comments on commit cc9f351

Please sign in to comment.