From 7647e52943ad15da6bc7e52df3dfb744aa949bf8 Mon Sep 17 00:00:00 2001 From: Johannes Gorset Date: Sat, 21 Jan 2012 12:05:59 +0100 Subject: [PATCH] Facilitate for multiple hooks --- requests/hooks.py | 13 +++++++---- test_requests.py | 59 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 68 insertions(+), 4 deletions(-) diff --git a/requests/hooks.py b/requests/hooks.py index 37f87d9399..dee4b1f627 100644 --- a/requests/hooks.py +++ b/requests/hooks.py @@ -31,10 +31,15 @@ def dispatch_hook(key, hooks, hook_data): hooks = hooks or dict() if key in hooks: - try: - return hooks.get(key).__call__(hook_data) or hook_data + hooks = hooks.get(key) - except Exception: - traceback.print_exc() + if hasattr(hooks, '__call__'): + hooks = [hooks] + + for hook in hooks: + try: + hook_data = hook(hook_data) or hook_data + except Exception: + traceback.print_exc() return hook_data diff --git a/test_requests.py b/test_requests.py index 172b1edaed..8915425565 100755 --- a/test_requests.py +++ b/test_requests.py @@ -516,6 +516,65 @@ def test_session_persistent_headers(self): self.assertEqual(r2.status_code, 200) + def test_single_hook(self): + + def add_foo_header(args): + if not args.get('headers'): + args['headers'] = {} + + args['headers'].update({ + 'X-Foo': 'foo' + }) + + return args + + for service in SERVICES: + url = service('headers') + + response = requests.get( + url = url, + hooks = { + 'args': add_foo_header + } + ) + + assert 'foo' in response.content + + def test_multiple_hooks(self): + + def add_foo_header(args): + if not args.get('headers'): + args['headers'] = {} + + args['headers'].update({ + 'X-Foo': 'foo' + }) + + return args + + def add_bar_header(args): + if not args.get('headers'): + args['headers'] = {} + + args['headers'].update({ + 'X-Bar': 'bar' + }) + + return args + + for service in SERVICES: + url = service('headers') + + response = requests.get( + url = url, + hooks = { + 'args': [add_foo_header, add_bar_header] + } + ) + + assert 'foo' in response.content + assert 'bar' in response.content + def test_session_persistent_cookies(self): s = requests.session()