From 64adf9f9df6a4e76f921a13ec81744af7438c4b2 Mon Sep 17 00:00:00 2001 From: Ian Cordasco Date: Fri, 17 Jul 2015 22:58:14 -0500 Subject: [PATCH] Change HTTPHeaderDict superclass to MutableMapping Fix dict(HTTPHeaderDict(other_dict)) and deal with changes in how we implement the HTTPHeaderDict as a MutableMapping subclass instead of as a dict subclass. --- test/test_collections.py | 2 +- urllib3/_collections.py | 48 ++++++++++++++++++++++------------------ 2 files changed, 28 insertions(+), 22 deletions(-) diff --git a/test/test_collections.py b/test/test_collections.py index 9f7ce9a72d..9d72939c95 100644 --- a/test/test_collections.py +++ b/test/test_collections.py @@ -237,7 +237,7 @@ def test_extend_from_headerdict(self): def test_copy(self): h = self.d.copy() self.assertTrue(self.d is not h) - self.assertEqual(self.d, h) + self.assertEqual(self.d, h) def test_getlist(self): self.assertEqual(self.d.getlist('cookie'), ['foo', 'bar']) diff --git a/urllib3/_collections.py b/urllib3/_collections.py index 6f530cef35..92a2f8bcf9 100644 --- a/urllib3/_collections.py +++ b/urllib3/_collections.py @@ -97,14 +97,7 @@ def keys(self): return list(iterkeys(self._container)) -_dict_setitem = dict.__setitem__ -_dict_getitem = dict.__getitem__ -_dict_delitem = dict.__delitem__ -_dict_contains = dict.__contains__ -_dict_setdefault = dict.setdefault - - -class HTTPHeaderDict(dict): +class HTTPHeaderDict(MutableMapping): """ :param headers: An iterable of field-value pairs. Must not contain multiple field names @@ -139,7 +132,8 @@ class HTTPHeaderDict(dict): """ def __init__(self, headers=None, **kwargs): - dict.__init__(self) + super(HTTPHeaderDict, self).__init__() + self._container = {} if headers is not None: if isinstance(headers, HTTPHeaderDict): self._copy_from(headers) @@ -149,24 +143,26 @@ def __init__(self, headers=None, **kwargs): self.extend(kwargs) def __setitem__(self, key, val): - return _dict_setitem(self, key.lower(), (key, val)) + self._container[key.lower()] = (key, val) + return self._container[key.lower()] def __getitem__(self, key): - val = _dict_getitem(self, key.lower()) + val = self._container[key.lower()] return ', '.join(val[1:]) def __delitem__(self, key): - return _dict_delitem(self, key.lower()) + del self._container[key.lower()] def __contains__(self, key): - return _dict_contains(self, key.lower()) + return key.lower() in self._container def __eq__(self, other): if not isinstance(other, Mapping) and not hasattr(other, 'keys'): return False if not isinstance(other, type(self)): other = type(self)(other) - return dict((k1, self[k1]) for k1 in self) == dict((k2, other[k2]) for k2 in other) + return (dict((k.lower(), v) for k, v in self.itermerged()) == + dict((k.lower(), v) for k, v in other.itermerged())) def __ne__(self, other): return not self.__eq__(other) @@ -181,6 +177,14 @@ def __ne__(self, other): __marker = object() + def __len__(self): + return len(self._container) + + def __iter__(self): + # Only provide the originally cased names + for vals in self._container.values(): + yield vals[0] + def pop(self, key, default=__marker): '''D.pop(k[,d]) -> v, remove specified key and return the corresponding value. If key is not found, d is returned if given, otherwise KeyError is raised. @@ -216,7 +220,7 @@ def add(self, key, val): key_lower = key.lower() new_vals = key, val # Keep the common case aka no item present as fast as possible - vals = _dict_setdefault(self, key_lower, new_vals) + vals = self._container.setdefault(key_lower, new_vals) if new_vals is not vals: # new_vals was not inserted, as there was a previous one if isinstance(vals, list): @@ -225,7 +229,7 @@ def add(self, key, val): else: # vals should be a tuple then, i.e. only one item so far # Need to convert the tuple to list for further extension - _dict_setitem(self, key_lower, [vals[0], vals[1], val]) + self._container[key_lower] = [vals[0], vals[1], val] def extend(self, *args, **kwargs): """Generic import function for any type of header-like object. @@ -257,7 +261,7 @@ def getlist(self, key): """Returns a list of all the values for the named field. Returns an empty list if the key doesn't exist.""" try: - vals = _dict_getitem(self, key.lower()) + vals = self._container[key.lower()] except KeyError: return [] else: @@ -276,11 +280,11 @@ def __repr__(self): def _copy_from(self, other): for key in other: - val = _dict_getitem(other, key) + val = other.getlist(key) if isinstance(val, list): # Don't need to convert tuples val = list(val) - _dict_setitem(self, key, val) + self._container[key.lower()] = [key] + val def copy(self): clone = type(self)() @@ -290,19 +294,21 @@ def copy(self): def iteritems(self): """Iterate over all header lines, including duplicate ones.""" for key in self: - vals = _dict_getitem(self, key) + vals = self._container[key.lower()] for val in vals[1:]: yield vals[0], val def itermerged(self): """Iterate over all headers, merging duplicate ones together.""" for key in self: - val = _dict_getitem(self, key) + val = self._container[key.lower()] yield val[0], ', '.join(val[1:]) def items(self): return list(self.iteritems()) + viewitems = items + @classmethod def from_httplib(cls, message): # Python 2 """Read headers from a Python 2 httplib message object."""