Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use HTTPHeaderDict in request #679

Merged
merged 4 commits on Jul 21, 2015
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 2 additions & 1 deletion test/test_collections.py
Expand Up @@ -237,7 +237,7 @@ def test_extend_from_headerdict(self):
def test_copy(self):
h = self.d.copy()
self.assertTrue(self.d is not h)
self.assertEqual(self.d, h)
self.assertEqual(self.d, h)

def test_getlist(self):
self.assertEqual(self.d.getlist('cookie'), ['foo', 'bar'])
Expand Down Expand Up @@ -302,6 +302,7 @@ def test_dict_conversion(self):
hdict = {'Content-Length': '0', 'Content-type': 'text/plain', 'Server': 'TornadoServer/1.2.3'}
h = dict(HTTPHeaderDict(hdict).items())
self.assertEqual(hdict, h)
self.assertEqual(hdict, dict(HTTPHeaderDict(hdict)))

def test_string_enforcement(self):
# This currently throws AttributeError on key.lower(), should probably be something nicer
Expand Down
19 changes: 18 additions & 1 deletion test/with_dummyserver/test_proxy_poolmanager.py
Expand Up @@ -9,6 +9,7 @@
DEFAULT_CA, DEFAULT_CA_BAD, get_unreachable_address)
from .. import TARPIT_HOST

from urllib3._collections import HTTPHeaderDict
from urllib3.poolmanager import proxy_from_url, ProxyManager
from urllib3.exceptions import (
MaxRetryError, SSLError, ProxyError, ConnectTimeoutError)
Expand Down Expand Up @@ -48,7 +49,7 @@ def test_nagle_proxy(self):

def test_proxy_conn_fail(self):
host, port = get_unreachable_address()
http = proxy_from_url('http://%s:%s/' % (host, port), retries=1)
http = proxy_from_url('http://%s:%s/' % (host, port), retries=1, timeout=0.05)
self.assertRaises(MaxRetryError, http.request, 'GET',
'%s/' % self.https_url)
self.assertRaises(MaxRetryError, http.request, 'GET',
Expand Down Expand Up @@ -223,6 +224,22 @@ def test_headers(self):
self.assertEqual(returned_headers.get('Host'),
'%s:%s'%(self.https_host,self.https_port))

def test_headerdict(self):
default_headers = HTTPHeaderDict(a='b')
proxy_headers = HTTPHeaderDict()
proxy_headers.add('foo', 'bar')

http = proxy_from_url(
self.proxy_url,
headers=default_headers,
proxy_headers=proxy_headers)

request_headers = HTTPHeaderDict(baz='quux')
r = http.request('GET', '%s/headers' % self.http_url, headers=request_headers)
returned_headers = json.loads(r.data.decode())
self.assertEqual(returned_headers.get('Foo'), 'bar')
self.assertEqual(returned_headers.get('Baz'), 'quux')

def test_proxy_pooling(self):
http = proxy_from_url(self.proxy_url)

Expand Down
37 changes: 36 additions & 1 deletion test/with_dummyserver/test_socketlevel.py
Expand Up @@ -13,6 +13,7 @@
from urllib3.util.ssl_ import HAS_SNI
from urllib3.util.timeout import Timeout
from urllib3.util.retry import Retry
from urllib3._collections import HTTPHeaderDict

from dummyserver.testcase import SocketDummyServerTestCase
from dummyserver.server import (
Expand Down Expand Up @@ -355,7 +356,7 @@ def echo_socket_handler(listener):
base_url = 'http://%s:%d' % (self.host, self.port)

# Define some proxy headers.
proxy_headers = {'For The Proxy': 'YEAH!'}
proxy_headers = HTTPHeaderDict({'For The Proxy': 'YEAH!'})
proxy = proxy_from_url(base_url, proxy_headers=proxy_headers)

conn = proxy.connection_from_url('http://www.google.com/')
Expand Down Expand Up @@ -617,6 +618,40 @@ def test_httplib_headers_case_insensitive(self):
r = pool.request('GET', '/')
self.assertEqual(HEADERS, dict(r.headers.items())) # to preserve case sensitivity

def test_headers_are_sent_with_the_original_case(self):
headers = {'foo': 'bar', 'bAz': 'quux'}
parsed_headers = {}

def socket_handler(listener):
sock = listener.accept()[0]

buf = b''
while not buf.endswith(b'\r\n\r\n'):
buf += sock.recv(65536)

headers_list = [header for header in buf.split(b'\r\n')[1:] if header]

for header in headers_list:
(key, value) = header.split(b': ')
parsed_headers[key.decode()] = value.decode()

# Send incomplete message (note Content-Length)
sock.send((
'HTTP/1.1 204 No Content\r\n'
'Content-Length: 0\r\n'
'\r\n').encode('utf-8'))

sock.close()

self._start_server(socket_handler)
expected_headers = {'Accept-Encoding': 'identity',
'Host': '{0}:{1}'.format(self.host, self.port)}
expected_headers.update(headers)

pool = HTTPConnectionPool(self.host, self.port, retries=False)
pool.request('GET', '/', headers=HTTPHeaderDict(headers))
self.assertEqual(expected_headers, parsed_headers)


class TestHEAD(SocketDummyServerTestCase):
def test_chunked_head_response_does_not_hang(self):
Expand Down
58 changes: 29 additions & 29 deletions urllib3/_collections.py
Expand Up @@ -97,14 +97,7 @@ def keys(self):
return list(iterkeys(self._container))


_dict_setitem = dict.__setitem__
_dict_getitem = dict.__getitem__
_dict_delitem = dict.__delitem__
_dict_contains = dict.__contains__
_dict_setdefault = dict.setdefault


class HTTPHeaderDict(dict):
class HTTPHeaderDict(MutableMapping):
"""
:param headers:
An iterable of field-value pairs. Must not contain multiple field names
Expand Down Expand Up @@ -139,7 +132,8 @@ class HTTPHeaderDict(dict):
"""

def __init__(self, headers=None, **kwargs):
dict.__init__(self)
super(HTTPHeaderDict, self).__init__()
self._container = {}
if headers is not None:
if isinstance(headers, HTTPHeaderDict):
self._copy_from(headers)
Expand All @@ -149,38 +143,44 @@ def __init__(self, headers=None, **kwargs):
self.extend(kwargs)

def __setitem__(self, key, val):
return _dict_setitem(self, key.lower(), (key, val))
self._container[key.lower()] = (key, val)
return self._container[key.lower()]

def __getitem__(self, key):
val = _dict_getitem(self, key.lower())
val = self._container[key.lower()]
return ', '.join(val[1:])

def __delitem__(self, key):
return _dict_delitem(self, key.lower())
del self._container[key.lower()]

def __contains__(self, key):
return _dict_contains(self, key.lower())
return key.lower() in self._container

def __eq__(self, other):
if not isinstance(other, Mapping) and not hasattr(other, 'keys'):
return False
if not isinstance(other, type(self)):
other = type(self)(other)
return dict((k1, self[k1]) for k1 in self) == dict((k2, other[k2]) for k2 in other)
return (dict((k.lower(), v) for k, v in self.itermerged()) ==
dict((k.lower(), v) for k, v in other.itermerged()))

def __ne__(self, other):
return not self.__eq__(other)

values = MutableMapping.values
get = MutableMapping.get
update = MutableMapping.update

if not PY3: # Python 2
iterkeys = MutableMapping.iterkeys
itervalues = MutableMapping.itervalues

__marker = object()

def __len__(self):
return len(self._container)

def __iter__(self):
# Only provide the originally cased names
for vals in self._container.values():
yield vals[0]

def pop(self, key, default=__marker):
'''D.pop(k[,d]) -> v, remove specified key and return the corresponding value.
If key is not found, d is returned if given, otherwise KeyError is raised.
Expand Down Expand Up @@ -216,7 +216,7 @@ def add(self, key, val):
key_lower = key.lower()
new_vals = key, val
# Keep the common case aka no item present as fast as possible
vals = _dict_setdefault(self, key_lower, new_vals)
vals = self._container.setdefault(key_lower, new_vals)
if new_vals is not vals:
# new_vals was not inserted, as there was a previous one
if isinstance(vals, list):
Expand All @@ -225,7 +225,7 @@ def add(self, key, val):
else:
# vals should be a tuple then, i.e. only one item so far
# Need to convert the tuple to list for further extension
_dict_setitem(self, key_lower, [vals[0], vals[1], val])
self._container[key_lower] = [vals[0], vals[1], val]

def extend(self, *args, **kwargs):
"""Generic import function for any type of header-like object.
Expand All @@ -236,7 +236,7 @@ def extend(self, *args, **kwargs):
raise TypeError("extend() takes at most 1 positional "
"arguments ({} given)".format(len(args)))
other = args[0] if len(args) >= 1 else ()

if isinstance(other, HTTPHeaderDict):
for key, val in other.iteritems():
self.add(key, val)
Expand All @@ -257,7 +257,7 @@ def getlist(self, key):
"""Returns a list of all the values for the named field. Returns an
empty list if the key doesn't exist."""
try:
vals = _dict_getitem(self, key.lower())
vals = self._container[key.lower()]
except KeyError:
return []
else:
Expand All @@ -276,11 +276,11 @@ def __repr__(self):

def _copy_from(self, other):
for key in other:
val = _dict_getitem(other, key)
val = other.getlist(key)
if isinstance(val, list):
# Don't need to convert tuples
val = list(val)
_dict_setitem(self, key, val)
self._container[key.lower()] = [key] + val

def copy(self):
clone = type(self)()
Expand All @@ -290,14 +290,14 @@ def copy(self):
def iteritems(self):
"""Iterate over all header lines, including duplicate ones."""
for key in self:
vals = _dict_getitem(self, key)
vals = self._container[key.lower()]
for val in vals[1:]:
yield vals[0], val

def itermerged(self):
"""Iterate over all headers, merging duplicate ones together."""
for key in self:
val = _dict_getitem(self, key)
val = self._container[key.lower()]
yield val[0], ', '.join(val[1:])

def items(self):
Expand All @@ -307,16 +307,16 @@ def items(self):
def from_httplib(cls, message): # Python 2
"""Read headers from a Python 2 httplib message object."""
# python2.7 does not expose a proper API for exporting multiheaders
# efficiently. This function re-reads raw lines from the message
# efficiently. This function re-reads raw lines from the message
# object and extracts the multiheaders properly.
headers = []

for line in message.headers:
if line.startswith((' ', '\t')):
key, value = headers[-1]
headers[-1] = (key, value + '\r\n' + line.rstrip())
continue

key, value = line.split(':', 1)
headers.append((key, value.strip()))

Expand Down