Skip to content

Commit

Permalink
Change HTTPHeaderDict superclass to MutableMapping
Browse files Browse the repository at this point in the history
Fix dict(HTTPHeaderDict(other_dict)) and deal with changes in how we
implement the HTTPHeaderDict as a MutableMapping subclass instead of
as a dict subclass.
  • Loading branch information
sigmavirus24 committed Jul 18, 2015
1 parent 838d23a commit 64adf9f
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 22 deletions.
2 changes: 1 addition & 1 deletion test/test_collections.py
Expand Up @@ -237,7 +237,7 @@ def test_extend_from_headerdict(self):
def test_copy(self):
h = self.d.copy()
self.assertTrue(self.d is not h)
self.assertEqual(self.d, h)
self.assertEqual(self.d, h)

def test_getlist(self):
self.assertEqual(self.d.getlist('cookie'), ['foo', 'bar'])
Expand Down
48 changes: 27 additions & 21 deletions urllib3/_collections.py
Expand Up @@ -97,14 +97,7 @@ def keys(self):
return list(iterkeys(self._container))


_dict_setitem = dict.__setitem__
_dict_getitem = dict.__getitem__
_dict_delitem = dict.__delitem__
_dict_contains = dict.__contains__
_dict_setdefault = dict.setdefault


class HTTPHeaderDict(dict):
class HTTPHeaderDict(MutableMapping):
"""
:param headers:
An iterable of field-value pairs. Must not contain multiple field names
Expand Down Expand Up @@ -139,7 +132,8 @@ class HTTPHeaderDict(dict):
"""

def __init__(self, headers=None, **kwargs):
dict.__init__(self)
super(HTTPHeaderDict, self).__init__()
self._container = {}
if headers is not None:
if isinstance(headers, HTTPHeaderDict):
self._copy_from(headers)
Expand All @@ -149,24 +143,26 @@ def __init__(self, headers=None, **kwargs):
self.extend(kwargs)

def __setitem__(self, key, val):
return _dict_setitem(self, key.lower(), (key, val))
self._container[key.lower()] = (key, val)
return self._container[key.lower()]

def __getitem__(self, key):
val = _dict_getitem(self, key.lower())
val = self._container[key.lower()]
return ', '.join(val[1:])

def __delitem__(self, key):
return _dict_delitem(self, key.lower())
del self._container[key.lower()]

def __contains__(self, key):
return _dict_contains(self, key.lower())
return key.lower() in self._container

def __eq__(self, other):
if not isinstance(other, Mapping) and not hasattr(other, 'keys'):
return False
if not isinstance(other, type(self)):
other = type(self)(other)
return dict((k1, self[k1]) for k1 in self) == dict((k2, other[k2]) for k2 in other)
return (dict((k.lower(), v) for k, v in self.itermerged()) ==
dict((k.lower(), v) for k, v in other.itermerged()))

def __ne__(self, other):
return not self.__eq__(other)
Expand All @@ -181,6 +177,14 @@ def __ne__(self, other):

__marker = object()

def __len__(self):
return len(self._container)

def __iter__(self):
# Only provide the originally cased names
for vals in self._container.values():
yield vals[0]

def pop(self, key, default=__marker):
'''D.pop(k[,d]) -> v, remove specified key and return the corresponding value.
If key is not found, d is returned if given, otherwise KeyError is raised.
Expand Down Expand Up @@ -216,7 +220,7 @@ def add(self, key, val):
key_lower = key.lower()
new_vals = key, val
# Keep the common case aka no item present as fast as possible
vals = _dict_setdefault(self, key_lower, new_vals)
vals = self._container.setdefault(key_lower, new_vals)
if new_vals is not vals:
# new_vals was not inserted, as there was a previous one
if isinstance(vals, list):
Expand All @@ -225,7 +229,7 @@ def add(self, key, val):
else:
# vals should be a tuple then, i.e. only one item so far
# Need to convert the tuple to list for further extension
_dict_setitem(self, key_lower, [vals[0], vals[1], val])
self._container[key_lower] = [vals[0], vals[1], val]

def extend(self, *args, **kwargs):
"""Generic import function for any type of header-like object.
Expand Down Expand Up @@ -257,7 +261,7 @@ def getlist(self, key):
"""Returns a list of all the values for the named field. Returns an
empty list if the key doesn't exist."""
try:
vals = _dict_getitem(self, key.lower())
vals = self._container[key.lower()]
except KeyError:
return []
else:
Expand All @@ -276,11 +280,11 @@ def __repr__(self):

def _copy_from(self, other):
for key in other:
val = _dict_getitem(other, key)
val = other.getlist(key)
if isinstance(val, list):
# Don't need to convert tuples
val = list(val)
_dict_setitem(self, key, val)
self._container[key.lower()] = [key] + val

def copy(self):
clone = type(self)()
Expand All @@ -290,19 +294,21 @@ def copy(self):
def iteritems(self):
"""Iterate over all header lines, including duplicate ones."""
for key in self:
vals = _dict_getitem(self, key)
vals = self._container[key.lower()]
for val in vals[1:]:
yield vals[0], val

def itermerged(self):
"""Iterate over all headers, merging duplicate ones together."""
for key in self:
val = _dict_getitem(self, key)
val = self._container[key.lower()]
yield val[0], ', '.join(val[1:])

def items(self):
return list(self.iteritems())

viewitems = items

@classmethod
def from_httplib(cls, message): # Python 2
"""Read headers from a Python 2 httplib message object."""
Expand Down

0 comments on commit 64adf9f

Please sign in to comment.