Commit

Merge pull request #679 from sigmavirus24/httpheaderdict-requests
Use HTTPHeaderDict in request
shazow committed Jul 21, 2015
2 parents 7fcfcd0 + b49490d · commit a76eb3b
Showing 4 changed files with 85 additions and 32 deletions.
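
In practical terms, the change lets a caller hand an HTTPHeaderDict straight to request()/urlopen(), so the original field casing and repeated field names survive into the outgoing request. A minimal offline sketch of the data structure involved (the header names and URL are illustrative only; no request is sent here):

    from urllib3._collections import HTTPHeaderDict

    headers = HTTPHeaderDict()
    headers.add('Accept', 'text/html')
    headers.add('accept', 'application/json')  # duplicate field names are kept, not clobbered

    # The original spellings and both values are preserved:
    for name, value in headers.iteritems():
        print('%s: %s' % (name, value))
    # Accept: text/html
    # Accept: application/json

    # After this change the mapping can be passed to request() as-is, e.g.
    # (left commented out because it needs a reachable server):
    # urllib3.PoolManager().request('GET', 'http://example.com/', headers=headers)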
test/test_collections.py (3 changes: 2 additions & 1 deletion)
@@ -237,7 +237,7 @@ def test_extend_from_headerdict(self):
def test_copy(self):
h = self.d.copy()
self.assertTrue(self.d is not h)
self.assertEqual(self.d, h)
self.assertEqual(self.d, h)

def test_getlist(self):
self.assertEqual(self.d.getlist('cookie'), ['foo', 'bar'])
@@ -302,6 +302,7 @@ def test_dict_conversion(self):
hdict = {'Content-Length': '0', 'Content-type': 'text/plain', 'Server': 'TornadoServer/1.2.3'}
h = dict(HTTPHeaderDict(hdict).items())
self.assertEqual(hdict, h)
self.assertEqual(hdict, dict(HTTPHeaderDict(hdict)))

def test_string_enforcement(self):
# This currently throws AttributeError on key.lower(), should probably be something nicer
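
The added assertion above relies on HTTPHeaderDict now being a MutableMapping, so dict() can consume it directly rather than only via items(). A small standalone sketch of that behaviour, assuming urllib3 with this change installed:

    from urllib3._collections import HTTPHeaderDict

    hdict = {'Content-Length': '0',
             'Content-type': 'text/plain',
             'Server': 'TornadoServer/1.2.3'}
    h = HTTPHeaderDict(hdict)

    assert dict(h.items()) == hdict    # explicit items() round-trip
    assert dict(h) == hdict            # works because keys()/__getitem__ come from MutableMapping
    assert h['content-LENGTH'] == '0'  # lookups stay case-insensitive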
test/with_dummyserver/test_proxy_poolmanager.py (19 changes: 18 additions & 1 deletion)
@@ -9,6 +9,7 @@
DEFAULT_CA, DEFAULT_CA_BAD, get_unreachable_address)
from .. import TARPIT_HOST

from urllib3._collections import HTTPHeaderDict
from urllib3.poolmanager import proxy_from_url, ProxyManager
from urllib3.exceptions import (
MaxRetryError, SSLError, ProxyError, ConnectTimeoutError)
@@ -48,7 +49,7 @@ def test_nagle_proxy(self):

def test_proxy_conn_fail(self):
host, port = get_unreachable_address()
http = proxy_from_url('http://%s:%s/' % (host, port), retries=1)
http = proxy_from_url('http://%s:%s/' % (host, port), retries=1, timeout=0.05)
self.assertRaises(MaxRetryError, http.request, 'GET',
'%s/' % self.https_url)
self.assertRaises(MaxRetryError, http.request, 'GET',
@@ -223,6 +224,22 @@ def test_headers(self):
self.assertEqual(returned_headers.get('Host'),
'%s:%s'%(self.https_host,self.https_port))

def test_headerdict(self):
default_headers = HTTPHeaderDict(a='b')
proxy_headers = HTTPHeaderDict()
proxy_headers.add('foo', 'bar')

http = proxy_from_url(
self.proxy_url,
headers=default_headers,
proxy_headers=proxy_headers)

request_headers = HTTPHeaderDict(baz='quux')
r = http.request('GET', '%s/headers' % self.http_url, headers=request_headers)
returned_headers = json.loads(r.data.decode())
self.assertEqual(returned_headers.get('Foo'), 'bar')
self.assertEqual(returned_headers.get('Baz'), 'quux')

def test_proxy_pooling(self):
http = proxy_from_url(self.proxy_url)

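
The new test_headerdict test passes HTTPHeaderDict instances for all three header slots (pool defaults, proxy headers, per-request headers). Roughly, the setup looks like the following; the proxy URL is a placeholder, and constructing the ProxyManager does not open any sockets:

    from urllib3._collections import HTTPHeaderDict
    from urllib3.poolmanager import proxy_from_url

    default_headers = HTTPHeaderDict(a='b')   # used for requests that pass no headers
    proxy_headers = HTTPHeaderDict()
    proxy_headers.add('foo', 'bar')           # sent to the proxy itself

    http = proxy_from_url('http://localhost:8080',
                          headers=default_headers,
                          proxy_headers=proxy_headers)

    # Per-request headers can also be an HTTPHeaderDict (needs a live proxy,
    # so left commented out):
    # r = http.request('GET', 'http://example.com/headers',
    #                  headers=HTTPHeaderDict(baz='quux'))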
test/with_dummyserver/test_socketlevel.py (37 changes: 36 additions & 1 deletion)
@@ -14,6 +14,7 @@
from urllib3.util.ssl_ import HAS_SNI
from urllib3.util.timeout import Timeout
from urllib3.util.retry import Retry
from urllib3._collections import HTTPHeaderDict

from dummyserver.testcase import SocketDummyServerTestCase
from dummyserver.server import (
@@ -421,7 +422,7 @@ def echo_socket_handler(listener):
base_url = 'http://%s:%d' % (self.host, self.port)

# Define some proxy headers.
proxy_headers = {'For The Proxy': 'YEAH!'}
proxy_headers = HTTPHeaderDict({'For The Proxy': 'YEAH!'})
proxy = proxy_from_url(base_url, proxy_headers=proxy_headers)

conn = proxy.connection_from_url('http://www.google.com/')
@@ -683,6 +684,40 @@ def test_httplib_headers_case_insensitive(self):
r = pool.request('GET', '/')
self.assertEqual(HEADERS, dict(r.headers.items())) # to preserve case sensitivity

def test_headers_are_sent_with_the_original_case(self):
headers = {'foo': 'bar', 'bAz': 'quux'}
parsed_headers = {}

def socket_handler(listener):
sock = listener.accept()[0]

buf = b''
while not buf.endswith(b'\r\n\r\n'):
buf += sock.recv(65536)

headers_list = [header for header in buf.split(b'\r\n')[1:] if header]

for header in headers_list:
(key, value) = header.split(b': ')
parsed_headers[key.decode()] = value.decode()

# Send incomplete message (note Content-Length)
sock.send((
'HTTP/1.1 204 No Content\r\n'
'Content-Length: 0\r\n'
'\r\n').encode('utf-8'))

sock.close()

self._start_server(socket_handler)
expected_headers = {'Accept-Encoding': 'identity',
'Host': '{0}:{1}'.format(self.host, self.port)}
expected_headers.update(headers)

pool = HTTPConnectionPool(self.host, self.port, retries=False)
pool.request('GET', '/', headers=HTTPHeaderDict(headers))
self.assertEqual(expected_headers, parsed_headers)


class TestBrokenHeaders(SocketDummyServerTestCase):
def setUp(self):
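
The new socket-level test asserts that header names reach the server with the caller's original casing. That property comes from HTTPHeaderDict, which indexes case-insensitively but remembers the spelling it was given; a quick illustration, assuming urllib3 with this change:

    from urllib3._collections import HTTPHeaderDict

    headers = HTTPHeaderDict({'foo': 'bar', 'bAz': 'quux'})

    # Iteration yields the originally supplied spellings, which is what the
    # connection writes out when the request is sent.
    assert sorted(headers) == ['bAz', 'foo']

    # Lookups and membership tests remain case-insensitive.
    assert headers['BAZ'] == 'quux'
    assert 'FOO' in headers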
urllib3/_collections.py (58 changes: 29 additions & 29 deletions)
@@ -97,14 +97,7 @@ def keys(self):
return list(iterkeys(self._container))


_dict_setitem = dict.__setitem__
_dict_getitem = dict.__getitem__
_dict_delitem = dict.__delitem__
_dict_contains = dict.__contains__
_dict_setdefault = dict.setdefault


class HTTPHeaderDict(dict):
class HTTPHeaderDict(MutableMapping):
"""
:param headers:
An iterable of field-value pairs. Must not contain multiple field names
@@ -139,7 +132,8 @@ class HTTPHeaderDict(dict):
"""

def __init__(self, headers=None, **kwargs):
dict.__init__(self)
super(HTTPHeaderDict, self).__init__()
self._container = {}
if headers is not None:
if isinstance(headers, HTTPHeaderDict):
self._copy_from(headers)
@@ -149,38 +143,44 @@ def __init__(self, headers=None, **kwargs):
self.extend(kwargs)

def __setitem__(self, key, val):
return _dict_setitem(self, key.lower(), (key, val))
self._container[key.lower()] = (key, val)
return self._container[key.lower()]

def __getitem__(self, key):
val = _dict_getitem(self, key.lower())
val = self._container[key.lower()]
return ', '.join(val[1:])

def __delitem__(self, key):
return _dict_delitem(self, key.lower())
del self._container[key.lower()]

def __contains__(self, key):
return _dict_contains(self, key.lower())
return key.lower() in self._container

def __eq__(self, other):
if not isinstance(other, Mapping) and not hasattr(other, 'keys'):
return False
if not isinstance(other, type(self)):
other = type(self)(other)
return dict((k1, self[k1]) for k1 in self) == dict((k2, other[k2]) for k2 in other)
return (dict((k.lower(), v) for k, v in self.itermerged()) ==
dict((k.lower(), v) for k, v in other.itermerged()))

def __ne__(self, other):
return not self.__eq__(other)

values = MutableMapping.values
get = MutableMapping.get
update = MutableMapping.update

if not PY3: # Python 2
iterkeys = MutableMapping.iterkeys
itervalues = MutableMapping.itervalues

__marker = object()

def __len__(self):
return len(self._container)

def __iter__(self):
# Only provide the originally cased names
for vals in self._container.values():
yield vals[0]

def pop(self, key, default=__marker):
'''D.pop(k[,d]) -> v, remove specified key and return the corresponding value.
If key is not found, d is returned if given, otherwise KeyError is raised.
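
Under the rewrite, each entry of the private _container dict maps a lowercased field name to an (original_name, value) tuple, promoted to an [original_name, value, value, ...] list once duplicates are added. The observable behaviour, sketched with an illustrative Set-Cookie header and assuming urllib3 with this change:

    from urllib3._collections import HTTPHeaderDict

    h = HTTPHeaderDict()
    h.add('Set-Cookie', 'foo=bar')   # stored as ('Set-Cookie', 'foo=bar')
    h.add('set-cookie', 'baz=quux')  # becomes ['Set-Cookie', 'foo=bar', 'baz=quux']

    assert h['SET-COOKIE'] == 'foo=bar, baz=quux'            # __getitem__ joins the values
    assert h.getlist('set-cookie') == ['foo=bar', 'baz=quux']
    assert list(h) == ['Set-Cookie']                         # casing of the first add() wins

    h['Set-Cookie'] = 'replaced'     # plain assignment overwrites all stored values
    assert h.getlist('set-cookie') == ['replaced']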
@@ -216,7 +216,7 @@ def add(self, key, val):
key_lower = key.lower()
new_vals = key, val
# Keep the common case aka no item present as fast as possible
vals = _dict_setdefault(self, key_lower, new_vals)
vals = self._container.setdefault(key_lower, new_vals)
if new_vals is not vals:
# new_vals was not inserted, as there was a previous one
if isinstance(vals, list):
@@ -225,7 +225,7 @@ def add(self, key, val):
else:
# vals should be a tuple then, i.e. only one item so far
# Need to convert the tuple to list for further extension
_dict_setitem(self, key_lower, [vals[0], vals[1], val])
self._container[key_lower] = [vals[0], vals[1], val]

def extend(self, *args, **kwargs):
"""Generic import function for any type of header-like object.
@@ -236,7 +236,7 @@ def extend(self, *args, **kwargs):
raise TypeError("extend() takes at most 1 positional "
"arguments ({} given)".format(len(args)))
other = args[0] if len(args) >= 1 else ()

if isinstance(other, HTTPHeaderDict):
for key, val in other.iteritems():
self.add(key, val)
@@ -257,7 +257,7 @@ def getlist(self, key):
"""Returns a list of all the values for the named field. Returns an
empty list if the key doesn't exist."""
try:
vals = _dict_getitem(self, key.lower())
vals = self._container[key.lower()]
except KeyError:
return []
else:
@@ -276,11 +276,11 @@ def __repr__(self):

def _copy_from(self, other):
for key in other:
val = _dict_getitem(other, key)
val = other.getlist(key)
if isinstance(val, list):
# Don't need to convert tuples
val = list(val)
_dict_setitem(self, key, val)
self._container[key.lower()] = [key] + val

def copy(self):
clone = type(self)()
@@ -290,14 +290,14 @@ def copy(self):
def iteritems(self):
"""Iterate over all header lines, including duplicate ones."""
for key in self:
vals = _dict_getitem(self, key)
vals = self._container[key.lower()]
for val in vals[1:]:
yield vals[0], val

def itermerged(self):
"""Iterate over all headers, merging duplicate ones together."""
for key in self:
val = _dict_getitem(self, key)
val = self._container[key.lower()]
yield val[0], ', '.join(val[1:])

def items(self):
@@ -307,16 +307,16 @@ def items(self):
def from_httplib(cls, message): # Python 2
"""Read headers from a Python 2 httplib message object."""
# python2.7 does not expose a proper API for exporting multiheaders
# efficiently. This function re-reads raw lines from the message
# efficiently. This function re-reads raw lines from the message
# object and extracts the multiheaders properly.
headers = []

for line in message.headers:
if line.startswith((' ', '\t')):
key, value = headers[-1]
headers[-1] = (key, value + '\r\n' + line.rstrip())
continue

key, value = line.split(':', 1)
headers.append((key, value.strip()))

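
With the storage hidden behind a MutableMapping interface, equality is defined over the case-folded, merged view (see __eq__ and itermerged above) rather than over the raw container, so comparisons against plain dicts and differently-cased instances behave intuitively. A short sketch, using an illustrative Vary header and assuming urllib3 with this change:

    from urllib3._collections import HTTPHeaderDict

    a = HTTPHeaderDict()
    a.add('Vary', 'Accept')
    a.add('vary', 'Cookie')

    assert a == {'VARY': 'Accept, Cookie'}  # mappings are coerced, then compared case-insensitively
    assert a != {'Vary': 'Accept'}          # different merged values compare unequal
    assert a != 'Vary: Accept, Cookie'      # non-mappings without keys() are never equal

    assert list(a.iteritems()) == [('Vary', 'Accept'), ('Vary', 'Cookie')]
    assert list(a.itermerged()) == [('Vary', 'Accept, Cookie')]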
