diff --git a/Lib/test/test_urllib2_localnet.py b/Lib/test/test_urllib2_localnet.py index 77dec0ce713ccb..a9e6832cfae923 100644 --- a/Lib/test/test_urllib2_localnet.py +++ b/Lib/test/test_urllib2_localnet.py @@ -7,6 +7,7 @@ import threading import unittest import hashlib +import mmap from test import support @@ -153,6 +154,25 @@ def _return_auth_challenge(self, request_handler): request_handler.wfile.write(b"Proxy Authentication Required.") return False + def _read_post_body(self, request_handler): + encoding = request_handler.headers.get("Transfer-Encoding") + if encoding is None: + length = int(request_handler.headers.get("Content-Length")) + if encoding == "chunked": + request = b"" + while True: + line = request_handler.rfile.readline() + length = int(line, 16) + if length == 0: + request_handler.rfile.readline() + break + request = request + request_handler.rfile.read(length) + request_handler.rfile.read(2) + else: + request = self.rfile.read(length) + + return request + def handle_request(self, request_handler): """Performs digest authentication on the given HTTP request handler. Returns True if authentication was successful, False @@ -162,6 +182,9 @@ def handle_request(self, request_handler): disabled and this method will always return True. """ + if request_handler.is_POST == True: + request_handler.post_body = self._read_post_body(request_handler) + if len(self._users) == 0: return True @@ -238,7 +261,39 @@ def do_GET(self): # Request Unauthorized self.do_AUTHHEAD() - + def do_POST(self): + encoding = self.headers.get("Transfer-Encoding") + if encoding is None: + length = int(self.headers.get("Content-Length")) + if encoding == "chunked": + request = b"" + while True: + line = self.rfile.readline() + length = int(line, 16) + if length == 0: + self.rfile.readline() + break + request = request + self.rfile.read(length) + self.rfile.read(2) + length = len(request) + else: + request = self.rfile.read(length) + if length > 0: + if not self.headers.get("Authorization", ""): + self.do_AUTHHEAD() + self.wfile.write(b"No Auth header received") + return + elif self.headers.get( + "Authorization", "") == "Basic " + self.ENCODED_AUTH: + self.send_response(200, "OK") + self.end_headers() + self.wfile.write(request) + else: + self.do_AUTHHEAD() + return + else: + self.send_response(400, "Empty request data") + self.end_headers() # Proxy test infrastructure @@ -264,6 +319,7 @@ def do_GET(self): (scm, netloc, path, params, query, fragment) = urllib.parse.urlparse( self.path, "http") self.short_path = path + self.is_POST = False if self.digest_auth_handler.handle_request(self): self.send_response(200, "OK") self.send_header("Content-Type", "text/html") @@ -273,6 +329,22 @@ def do_GET(self): self.wfile.write(b"Our apologies, but our server is down due to " b"a sudden zombie invasion.") + def do_POST(self): + (scm, netloc, path, params, query, fragment) = urllib.parse.urlparse( + self.path, "http") + self.short_path = path + self.is_POST = True + if self.digest_auth_handler.handle_request(self): + request = self.post_body + if len(request) > 0: + self.send_response(200, "OK") + self.send_header("Content-Type", "text/plain") + self.end_headers() + self.wfile.write(request) + else: + self.send_response(400, "Empty request data") + self.end_headers() + # Test cases class BasicAuthTests(unittest.TestCase): @@ -314,6 +386,21 @@ def test_basic_auth_httperror(self): urllib.request.install_opener(urllib.request.build_opener(ah)) self.assertRaises(urllib.error.HTTPError, urllib.request.urlopen, self.server_url) + def test_basic_auth_mmap_post(self): + data = "field=value".encode("ascii") + mem = mmap.mmap(-1, len(data)) + mem[:] = data + auth_handler = urllib.request.HTTPBasicAuthHandler() + auth_handler.add_password(realm=self.REALM, uri=self.server_url, + user=self.USER, passwd=self.PASSWD) + opener = urllib.request.build_opener(auth_handler) + urllib.request.install_opener(opener) + try: + response = urllib.request.urlopen(self.server_url, mem) + except urllib.error.HTTPError: + self.fail("Basic auth failed for the url: %s" % self.server_url) + response_data = response.read() + self.assertEqual(response_data, data) class ProxyAuthTests(unittest.TestCase): URL = "http://localhost" @@ -392,6 +479,15 @@ def test_proxy_qop_auth_int_works_or_throws_urlerror(self): pass result.close() + def test_proxy_auth_mmap_post(self): + data = "field=value".encode("ascii") + mem = mmap.mmap(-1, len(data)) + mem[:] = data + self.proxy_digest_handler.add_password(self.REALM, self.URL, + self.USER, self.PASSWD) + response = self.opener.open(self.URL, mem) + self.assertEqual(response.read(), data) + def GetRequestHandler(responses): diff --git a/Lib/urllib/request.py b/Lib/urllib/request.py index 9a3d399f018931..6112185564490a 100644 --- a/Lib/urllib/request.py +++ b/Lib/urllib/request.py @@ -936,6 +936,19 @@ def is_authenticated(self, authuri): if self.is_suburi(uri, reduced_authuri): return self.authenticated[uri] +class _AuthHandlerFileReiterator: + def __init__(self, file, startPosition): + self.file = file + self.startPosition = startPosition + self.chunksize = 8196 + + def __iter__(self): + self.file.seek(self.startPosition) + while True: + chunk = self.file.read(self.chunksize) + yield chunk + if len(chunk) < self.chunksize: + break class AbstractBasicAuthHandler: @@ -956,6 +969,7 @@ def __init__(self, password_mgr=None): password_mgr = HTTPPasswordMgr() self.passwd = password_mgr self.add_password = self.passwd.add_password + self.file_start_position = 0 def http_error_auth_reqed(self, authreq, host, req, headers): # host may be an authority (without userinfo) or a URL with an @@ -987,11 +1001,15 @@ def retry_http_basic_auth(self, host, req, realm): if req.get_header(self.auth_header, None) == auth: return None req.add_unredirected_header(self.auth_header, auth) + if hasattr(req._data, "read"): + req._data = _AuthHandlerFileReiterator(req._data, self.file_start_position) return self.parent.open(req, timeout=req.timeout) else: return None def http_request(self, req): + if hasattr(req._data, "read"): + self.file_start_position = req._data.tell() if (not hasattr(self.passwd, 'is_authenticated') or not self.passwd.is_authenticated(req.full_url)): return req @@ -1066,6 +1084,7 @@ def __init__(self, passwd=None): self.retried = 0 self.nonce_count = 0 self.last_nonce = None + self.file_start_position = 0 def reset_retry_count(self): self.retried = 0 @@ -1099,6 +1118,8 @@ def retry_http_digest_auth(self, req, auth): if req.headers.get(self.auth_header, None) == auth_val: return None req.add_unredirected_header(self.auth_header, auth_val) + if hasattr(req._data, "read"): + req._data = _AuthHandlerFileReiterator(req._data, self.file_start_position) resp = self.parent.open(req, timeout=req.timeout) return resp @@ -1190,6 +1211,12 @@ def get_entity_digest(self, data, chal): # XXX not implemented yet return None + def http_request(self, req): + if hasattr(req._data, "read"): + self.file_start_position = req._data.tell() + return req + + https_request = http_request class HTTPDigestAuthHandler(BaseHandler, AbstractDigestAuthHandler): """An authentication protocol defined by RFC 2069 diff --git a/Misc/NEWS.d/next/Library/2019-02-17-14-34-56.bpo-5038.IJlB1u.rst b/Misc/NEWS.d/next/Library/2019-02-17-14-34-56.bpo-5038.IJlB1u.rst new file mode 100644 index 00000000000000..cd1370f53eaf31 --- /dev/null +++ b/Misc/NEWS.d/next/Library/2019-02-17-14-34-56.bpo-5038.IJlB1u.rst @@ -0,0 +1 @@ +urllib now allows mmap'ed files to be resent when the http server responds with 401 unauthorized for basic and digest authentication. \ No newline at end of file