Skip to content

Commit

Permalink
Issue #7776: Fix ``Host:'' header and reconnection when using http.cl…
Browse files Browse the repository at this point in the history
…ient.HTTPConnection.set_tunnel().

Patch by Nikolaus Rath.
  • Loading branch information
orsenthil committed Apr 14, 2014
1 parent b814057 commit 9da047b
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 26 deletions.
73 changes: 49 additions & 24 deletions Lib/http/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -747,22 +747,38 @@ def __init__(self, host, port=None, timeout=socket._GLOBAL_DEFAULT_TIMEOUT,
self._tunnel_port = None
self._tunnel_headers = {}

self._set_hostport(host, port)
(self.host, self.port) = self._get_hostport(host, port)

# This is stored as an instance variable to allow unit
# tests to replace it with a suitable mockup
self._create_connection = socket.create_connection

def set_tunnel(self, host, port=None, headers=None):
""" Sets up the host and the port for the HTTP CONNECT Tunnelling.
"""Set up host and port for HTTP CONNECT tunnelling.
In a connection that uses HTTP CONNECT tunneling, the host passed to the
constructor is used as a proxy server that relays all communication to
the endpoint passed to `set_tunnel`. This done by sending an HTTP
CONNECT request to the proxy server when the connection is established.
The headers argument should be a mapping of extra HTTP headers
to send with the CONNECT request.
This method must be called before the HTML connection has been
established.
The headers argument should be a mapping of extra HTTP headers to send
with the CONNECT request.
"""

if self.sock:
raise RuntimeError("Can't set up tunnel for established connection")

self._tunnel_host = host
self._tunnel_port = port
if headers:
self._tunnel_headers = headers
else:
self._tunnel_headers.clear()

def _set_hostport(self, host, port):
def _get_hostport(self, host, port):
if port is None:
i = host.rfind(':')
j = host.rfind(']') # ipv6 addresses have [...]
Expand All @@ -779,15 +795,16 @@ def _set_hostport(self, host, port):
port = self.default_port
if host and host[0] == '[' and host[-1] == ']':
host = host[1:-1]
self.host = host
self.port = port

return (host, port)

def set_debuglevel(self, level):
self.debuglevel = level

def _tunnel(self):
self._set_hostport(self._tunnel_host, self._tunnel_port)
connect_str = "CONNECT %s:%d HTTP/1.0\r\n" % (self.host, self.port)
(host, port) = self._get_hostport(self._tunnel_host,
self._tunnel_port)
connect_str = "CONNECT %s:%d HTTP/1.0\r\n" % (host, port)
connect_bytes = connect_str.encode("ascii")
self.send(connect_bytes)
for header, value in self._tunnel_headers.items():
Expand Down Expand Up @@ -815,8 +832,9 @@ def _tunnel(self):

def connect(self):
"""Connect to the host and port specified in __init__."""
self.sock = socket.create_connection((self.host,self.port),
self.timeout, self.source_address)
self.sock = self._create_connection((self.host,self.port),
self.timeout, self.source_address)

if self._tunnel_host:
self._tunnel()

Expand Down Expand Up @@ -985,22 +1003,29 @@ def putrequest(self, method, url, skip_host=0, skip_accept_encoding=0):
netloc_enc = netloc.encode("idna")
self.putheader('Host', netloc_enc)
else:
if self._tunnel_host:
host = self._tunnel_host
port = self._tunnel_port
else:
host = self.host
port = self.port

try:
host_enc = self.host.encode("ascii")
host_enc = host.encode("ascii")
except UnicodeEncodeError:
host_enc = self.host.encode("idna")
host_enc = host.encode("idna")

# As per RFC 273, IPv6 address should be wrapped with []
# when used as Host header

if self.host.find(':') >= 0:
if host.find(':') >= 0:
host_enc = b'[' + host_enc + b']'

if self.port == self.default_port:
if port == self.default_port:
self.putheader('Host', host_enc)
else:
host_enc = host_enc.decode("ascii")
self.putheader('Host', "%s:%s" % (host_enc, self.port))
self.putheader('Host', "%s:%s" % (host_enc, port))

# note: we are assuming that clients will not attempt to set these
# headers since *this* library must deal with the
Expand Down Expand Up @@ -1193,19 +1218,19 @@ def __init__(self, host, port=None, key_file=None, cert_file=None,
def connect(self):
"Connect to a host on a given (SSL) port."

sock = socket.create_connection((self.host, self.port),
self.timeout, self.source_address)
super().connect()

if self._tunnel_host:
self.sock = sock
self._tunnel()
server_hostname = self._tunnel_host
else:
server_hostname = self.host
sni_hostname = server_hostname if ssl.HAS_SNI else None

server_hostname = self.host if ssl.HAS_SNI else None
self.sock = self._context.wrap_socket(sock,
server_hostname=server_hostname)
self.sock = self._context.wrap_socket(self.sock,
server_hostname=sni_hostname)
if not self._context.check_hostname and self._check_hostname:
try:
ssl.match_hostname(self.sock.getpeercert(), self.host)
ssl.match_hostname(self.sock.getpeercert(), server_hostname)
except Exception:
self.sock.shutdown(socket.SHUT_RDWR)
self.sock.close()
Expand Down
50 changes: 48 additions & 2 deletions Lib/test/test_httplib.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,15 @@
HOST = support.HOST

class FakeSocket:
def __init__(self, text, fileclass=io.BytesIO):
def __init__(self, text, fileclass=io.BytesIO, host=None, port=None):
if isinstance(text, str):
text = text.encode("ascii")
self.text = text
self.fileclass = fileclass
self.data = b''
self.sendall_calls = 0
self.host = host
self.port = port

def sendall(self, data):
self.sendall_calls += 1
Expand All @@ -38,6 +40,9 @@ def makefile(self, mode, bufsize=None):
raise client.UnimplementedFileMode()
return self.fileclass(self.text)

def close(self):
pass

class EPipeSocket(FakeSocket):

def __init__(self, text, pipe_trigger):
Expand Down Expand Up @@ -970,10 +975,51 @@ def test_getting_header_defaultint(self):
header = self.resp.getheader('No-Such-Header',default=42)
self.assertEqual(header, 42)

class TunnelTests(TestCase):

def test_connect(self):
response_text = (
'HTTP/1.0 200 OK\r\n\r\n' # Reply to CONNECT
'HTTP/1.1 200 OK\r\n' # Reply to HEAD
'Content-Length: 42\r\n\r\n'
)

def create_connection(address, timeout=None, source_address=None):
return FakeSocket(response_text, host=address[0],
port=address[1])

conn = client.HTTPConnection('proxy.com')
conn._create_connection = create_connection

# Once connected, we shouldn't be able to tunnel anymore
conn.connect()
self.assertRaises(RuntimeError, conn.set_tunnel,
'destination.com')

# But if we close the connection, we're good
conn.close()
conn.set_tunnel('destination.com')
conn.request('HEAD', '/', '')

self.assertEqual(conn.sock.host, 'proxy.com')
self.assertEqual(conn.sock.port, 80)
self.assertTrue(b'CONNECT destination.com' in conn.sock.data)
self.assertTrue(b'Host: destination.com' in conn.sock.data)

# This test should be removed when CONNECT gets the HTTP/1.1 blessing
self.assertTrue(b'Host: proxy.com' not in conn.sock.data)

conn.close()
conn.request('PUT', '/', '')
self.assertEqual(conn.sock.host, 'proxy.com')
self.assertEqual(conn.sock.port, 80)
self.assertTrue(b'CONNECT destination.com' in conn.sock.data)
self.assertTrue(b'Host: destination.com' in conn.sock.data)

def test_main(verbose=None):
support.run_unittest(HeaderTests, OfflineTest, BasicTest, TimeoutTest,
HTTPSTest, RequestBodyTest, SourceAddressTest,
HTTPResponseTest)
HTTPResponseTest, TunnelTests)

if __name__ == '__main__':
test_main()
3 changes: 3 additions & 0 deletions Misc/NEWS
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ Core and Builtins
Library
-------

- Issue #7776: Fix ``Host:'' header and reconnection when using
http.client.HTTPConnection.set_tunnel(). Patch by Nikolaus Rath.

- Issue #20968: unittest.mock.MagicMock now supports division.
Patch by Johannes Baiter.

Expand Down

0 comments on commit 9da047b

Please sign in to comment.