Permalink
Browse files

Fixed default headers not getting fixed in some cases. (Fixed #99)

RequestMethods is now headers-aware and has an initializer. All
inheriting classes call it to set the default headers.
  • Loading branch information...
1 parent 7f48f6b commit dfe0ea1f96dc64a1b0dd26f3029cf3f239087745 @shazow committed Sep 15, 2012
Showing with 58 additions and 10 deletions.
  1. +2 −0 CHANGES.rst
  2. +4 −0 dummyserver/handlers.py
  3. +24 −0 test/with_dummyserver/test_poolmanager.py
  4. +5 −5 urllib3/connectionpool.py
  5. +8 −1 urllib3/poolmanager.py
  6. +15 −4 urllib3/request.py
View
@@ -6,6 +6,8 @@ dev (master branch)
* Exceptions are now pickleable, with tests. (Issue #101)
+* Fixed default headers not getting passed in some cases. (Issue #99)
+
1.5 (2012-08-02)
++++++++++++++++
@@ -1,6 +1,7 @@
from __future__ import print_function
import gzip
+import json
import logging
import sys
import time
@@ -161,5 +162,8 @@ def encodingrequest(self, request):
data = 'garbage'
return Response(data, headers=headers)
+ def headers(self, request):
+ return Response(json.dumps(request.headers))
+
def shutdown(self, request):
sys.exit()
@@ -1,4 +1,5 @@
import unittest
+import json
from dummyserver.testcase import HTTPDummyServerTestCase
from urllib3.poolmanager import PoolManager
@@ -62,6 +63,29 @@ def test_missing_port(self):
self.assertEqual(r.status, 200)
self.assertEqual(r.data, b'Dummy server!')
+ def test_headers(self):
+ http = PoolManager(headers={'Foo': 'bar'})
+
+ r = http.request_encode_url('GET', '%s/headers' % self.base_url)
+ returned_headers = json.loads(r.data.decode())
+ self.assertEqual(returned_headers.get('Foo'), 'bar')
+
+ r = http.request_encode_body('POST', '%s/headers' % self.base_url)
+ returned_headers = json.loads(r.data.decode())
+ self.assertEqual(returned_headers.get('Foo'), 'bar')
+
+ r = http.request_encode_url('GET', '%s/headers' % self.base_url, headers={'Baz': 'quux'})
+ returned_headers = json.loads(r.data.decode())
+ self.assertEqual(returned_headers.get('Foo'), None)
+ self.assertEqual(returned_headers.get('Baz'), 'quux')
+
+ r = http.request_encode_body('GET', '%s/headers' % self.base_url, headers={'Baz': 'quux'})
+ returned_headers = json.loads(r.data.decode())
+ self.assertEqual(returned_headers.get('Foo'), None)
+ self.assertEqual(returned_headers.get('Baz'), 'quux')
+
+
+
if __name__ == '__main__':
unittest.main()
@@ -166,13 +166,13 @@ class HTTPConnectionPool(ConnectionPool, RequestMethods):
def __init__(self, host, port=None, strict=False, timeout=None, maxsize=1,
block=False, headers=None):
- super(HTTPConnectionPool, self).__init__(host, port)
+ ConnectionPool.__init__(self, host, port)
+ RequestMethods.__init__(self, headers)
self.strict = strict
self.timeout = timeout
self.pool = self.QueueCls(maxsize)
self.block = block
- self.headers = headers or {}
# Fill the queue up so that doing get() on it will block properly
for _ in xrange(maxsize):
@@ -506,9 +506,9 @@ def __init__(self, host, port=None,
key_file=None, cert_file=None,
cert_reqs='CERT_NONE', ca_certs=None):
- super(HTTPSConnectionPool, self).__init__(host, port,
- strict, timeout, maxsize,
- block, headers)
+ HTTPConnectionPool.__init__(self, host, port,
+ strict, timeout, maxsize,
+ block, headers)
self.key_file = key_file
self.cert_file = cert_file
self.cert_reqs = cert_reqs
@@ -33,6 +33,10 @@ class PoolManager(RequestMethods):
Number of connection pools to cache before discarding the least
recently used pool.
+ :param headers:
+ Headers to include with all requests, unless other headers are given
+ explicitly.
+
:param \**connection_pool_kw:
Additional parameters are used to create fresh
:class:`urllib3.connectionpool.ConnectionPool` instances.
@@ -48,7 +52,8 @@ class PoolManager(RequestMethods):
"""
- def __init__(self, num_pools=10, **connection_pool_kw):
+ def __init__(self, num_pools=10, headers=None, **connection_pool_kw):
+ RequestMethods.__init__(self, headers)
self.connection_pool_kw = connection_pool_kw
self.pools = RecentlyUsedContainer(num_pools,
dispose_func=lambda p: p.close())
@@ -113,6 +118,8 @@ def urlopen(self, method, url, redirect=True, **kw):
kw['assert_same_host'] = False
kw['redirect'] = False
+ if 'headers' not in kw:
+ kw['headers'] = self.headers
response = conn.urlopen(method, u.request_uri, **kw)
View
@@ -36,12 +36,20 @@ class RequestMethods(object):
:meth:`.request` is for making any kind of request, it will look up the
appropriate encoding format and use one of the above two methods to make
the request.
+
+ Initializer parameters:
+
+ :param headers:
+ Headers to include with all requests, unless other headers are given
+ explicitly.
"""
_encode_url_methods = set(['DELETE', 'GET', 'HEAD', 'OPTIONS'])
-
_encode_body_methods = set(['PATCH', 'POST', 'PUT', 'TRACE'])
+ def __init__(self, headers=None):
+ self.headers = headers or {}
+
def urlopen(self, method, url, body=None, headers=None,
encode_multipart=True, multipart_boundary=None,
**kw): # Abstract
@@ -121,8 +129,11 @@ def request_encode_body(self, method, url, fields=None, headers=None,
body, content_type = (urlencode(fields or {}),
'application/x-www-form-urlencoded')
- headers = headers or {}
- headers.update({'Content-Type': content_type})
+ if headers is None:
+ headers = self.headers
+
+ headers_ = {'Content-Type': content_type}
+ headers_.update(headers)
- return self.urlopen(method, url, body=body, headers=headers,
+ return self.urlopen(method, url, body=body, headers=headers_,
**urlopen_kw)

0 comments on commit dfe0ea1

Please sign in to comment.