Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix infinite loop on wrong Digest Authentication #547

Merged
merged 2 commits into from Apr 13, 2012
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 2 additions & 0 deletions requests/auth.py
Expand Up @@ -56,6 +56,8 @@ def __init__(self, username, password):
def handle_401(self, r):
"""Takes the given response and tries digest-auth, if needed."""

r.request.deregister_hook('response', self.handle_401)

s_auth = r.headers.get('www-authenticate', '')

if 'digest' in s_auth.lower():
Expand Down
62 changes: 39 additions & 23 deletions requests/models.py
Expand Up @@ -80,7 +80,7 @@ def __init__(self,
self.headers = dict(headers or [])

#: Dictionary of files to multipart upload (``{filename: content}``).
self.files = files
self.files = None

#: HTTP Method to use.
self.method = method
Expand Down Expand Up @@ -114,6 +114,7 @@ def __init__(self,

self.data, self._enc_data = self._encode_params(data)
self.params, self._enc_params = self._encode_params(params)
self.files, self._enc_files = self._encode_files(files)

#: :class:`Response <Response>` instance, containing
#: content and metadata of HTTP Response, once :attr:`sent <send>`.
Expand Down Expand Up @@ -329,6 +330,29 @@ def _encode_params(data):
else:
return data, data

def _encode_files(self,files):

if (not files) or isinstance(self.data, str):
return None, None

try:
fields = self.data.copy()
except AttributeError:
fields = dict(self.data)

for (k, v) in list(files.items()):
# support for explicit filename
if isinstance(v, (tuple, list)):
fn, fp = v
else:
fn = guess_filename(v) or k
fp = v
fields.update({k: (fn, fp.read())})

(body, content_type) = encode_multipart_formdata(fields)

return files, (body, content_type)

@property
def full_url(self):
"""Build the actual URL to use."""
Expand Down Expand Up @@ -408,7 +432,18 @@ def path_url(self):
def register_hook(self, event, hook):
"""Properly register a hook."""

return self.hooks[event].append(hook)
self.hooks[event].append(hook)

def deregister_hook(self,event,hook):
"""Deregister a previously registered hook.
Returns True if the hook existed, False if not.
"""

try:
self.hooks[event].remove(hook)
return True
except ValueError:
return False

def send(self, anyway=False, prefetch=False):
"""Sends the request. Returns True of successful, False if not.
Expand Down Expand Up @@ -436,26 +471,7 @@ def send(self, anyway=False, prefetch=False):

# Multi-part file uploads.
if self.files:
if not isinstance(self.data, str):

try:
fields = self.data.copy()
except AttributeError:
fields = dict(self.data)

for (k, v) in list(self.files.items()):
# support for explicit filename
if isinstance(v, (tuple, list)):
fn, fp = v
else:
fn = guess_filename(v) or k
fp = v
fields.update({k: (fn, fp.read())})

(body, content_type) = encode_multipart_formdata(fields)
else:
pass
# TODO: Conflict?
(body, content_type) = self._enc_files
else:
if self.data:

Expand Down Expand Up @@ -752,7 +768,7 @@ def content(self):
except AttributeError:
self._content = None

self._content_consumed = True
self._content_consumed = True
return self._content

def _detected_encoding(self):
Expand Down
14 changes: 14 additions & 0 deletions tests/test_requests.py
Expand Up @@ -272,6 +272,20 @@ def test_DIGESTAUTH_HTTP_200_OK_GET(self):
r = get(url, session=s)
self.assertEqual(r.status_code, 200)

def test_DIGESTAUTH_WRONG_HTTP_401_GET(self):

for service in SERVICES:

auth = HTTPDigestAuth('user', 'wrongpass')
url = service('digest-auth', 'auth', 'user', 'pass')

r = get(url, auth=auth)
self.assertEqual(r.status_code, 401)

s = requests.session(auth=auth)
r = get(url, session=s)
self.assertEqual(r.status_code, 401)

def test_POSTBIN_GET_POST_FILES(self):

for service in SERVICES:
Expand Down