Skip to content

Commit

Permalink
Refactored HTTPS code, added unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
shazow committed Dec 11, 2009
1 parent c77409c commit eea79db
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 48 deletions.
27 changes: 25 additions & 2 deletions test/test_connectionpool.py
Expand Up @@ -3,7 +3,7 @@
import sys
sys.path.append('../')

from urllib3 import HTTPConnectionPool
from urllib3.connectionpool import HTTPConnectionPool, get_host, connection_from_url, HostChangedError

class TestConnectionPool(unittest.TestCase):
def test_get_host(self):
Expand All @@ -21,6 +21,29 @@ def test_get_host(self):
'https://google.com:8000': ('https', 'google.com', 8000),
}
for url, expected_host in url_host_map.iteritems():
returned_host = HTTPConnectionPool.get_host(url)
returned_host = get_host(url)
self.assertEquals(returned_host, expected_host)

def test_same_host(self):
same_host = [
('http://google.com/', '/'),
('http://google.com/', 'http://google.com/'),
('http://google.com/', 'http://google.com'),
('http://google.com/', 'http://google.com/abra/cadabra'),
('http://google.com:42/', 'http://google.com:42/abracadabra'),
]

for a,b in same_host:
c = connection_from_url(a)
self.assertTrue(c.is_same_host(b), "%s =? %s" % (a, b))

not_same_host = [
('http://yahoo.com/', 'http://google.com/'),
('http://google.com:42', 'https://google.com/abracadabra'),
('http://google.com', 'https://google.net/'),
]

for a,b in not_same_host:
c = connection_from_url(a)
self.assertFalse(c.is_same_host(b), "%s =? %s" % (a,b))

2 changes: 1 addition & 1 deletion urllib3/__init__.py
@@ -1,4 +1,4 @@
from connectionpool import HTTPConnectionPool
from connectionpool import HTTPConnectionPool, HTTPSConnectionPool, get_host, connection_from_url
from filepost import encode_multipart_formdata

# Possible exceptions
Expand Down
123 changes: 78 additions & 45 deletions urllib3/connectionpool.py
Expand Up @@ -25,6 +25,10 @@ class TimeoutError(HTTPError):
"Raised when a socket timeout occurs."
pass

class HostChangedError(HTTPError):
"Raised when an existing pool gets a request for a foreign host."
pass

## Response objects

class HTTPResponse(object):
Expand Down Expand Up @@ -94,13 +98,11 @@ class HTTPConnectionPool(object):
until a connection has been released. This is a useful side effect for
particular multithreaded situations where one does not want to use more
than maxsize connections per host to prevent flooding.
"""

ssl
If set to True, a HTTPSConnection is used instead of a HTTPConnection.
The default port for a HTTPSConnection is 443.
scheme = 'http'

"""
def __init__(self, host, port=None, timeout=None, maxsize=1, block=False, ssl=False):
def __init__(self, host, port=None, timeout=None, maxsize=1, block=False):
self.host = host
self.port = port
self.timeout = timeout
Expand All @@ -113,52 +115,13 @@ def __init__(self, host, port=None, timeout=None, maxsize=1, block=False, ssl=Fa
self.num_connections = 0
self.num_requests = 0

self.ConnectionCls = HTTPConnection
if ssl:
self.ConnectionCls = HTTPSConnection

@staticmethod
def get_host(url):
"""
Given a url, return its host and port (None if it's not there).
For example:
>>> HTTPConnectionPool.get_host('http://google.com/mail/')
google.com, None
>>> HTTPConnectionPool.get_host('google.com:80')
google.com, 80
"""
# This code is actually similar to urlparse.urlsplit, but much
# simplified for our needs.
port = None
scheme = 'http'
if '//' in url:
scheme, url = url.split('://', 1)
if '/' in url:
url, path = url.split('/', 1)
if ':' in url:
url, port = url.split(':', 1)
port = int(port)
return scheme, url, port

@staticmethod
def from_url(url, timeout=None, maxsize=10):
"""
Given a url, return an HTTPConnectionPool instance of its host.
This is a shortcut for not having to determine the host of the url
before creating an HTTPConnectionPool instance.
"""
scheme, host, port = HTTPConnectionPool.get_host(url)
return HTTPConnectionPool(host, port=port, timeout=timeout, maxsize=maxsize, ssl=scheme=='https')

def _new_conn(self):
"""
Return a fresh HTTPConnection.
"""
self.num_connections += 1
log.info("Starting new HTTP connection (%d): %s" % (self.num_connections, self.host))
return self.ConnectionCls(host=self.host, port=self.port)
return HTTPConnection(host=self.host, port=self.port)

def _get_conn(self, timeout=None):
"""
Expand Down Expand Up @@ -186,6 +149,9 @@ def _put_conn(self, conn):
# This should never happen if self.block == True
log.warning("HttpConnectionPool is full, discarding connection: %s" % self.host)

def is_same_host(self, url):
return url.startswith('/') or get_host(url) == (self.scheme, self.host, self.port)

def urlopen(self, method, url, body=None, headers={}, retries=3, redirect=True):
"""
Get a connection from the pool and perform an HTTP request.
Expand All @@ -210,6 +176,15 @@ def urlopen(self, method, url, body=None, headers={}, retries=3, redirect=True):
if retries < 0:
raise MaxRetryError("Max retries exceeded for url: %s" % url)

# Check host
if not self.is_same_host(url):
host = "%s://%s" % (self.scheme, self.host)
if self.port:
host = "%s:%d" % (host, self.port)

raise HostChangedError("Connection pool with host '%s' tried to open a foreign host: %s" % (host, url))


try:
# Request a connection from the queue
conn = self._get_conn()
Expand Down Expand Up @@ -274,3 +249,61 @@ def post_url(self, url, fields={}, headers={}, retries=3, redirect=True):
body, content_type = encode_multipart_formdata(fields)
headers.update({'Content-Type': content_type})
return self.urlopen('POST', url, body, headers=headers, retries=retries, redirect=redirect)


class HTTPSConnectionPool(HTTPConnectionPool):
"""
Same as HTTPConnectionPool, but HTTPS.
"""

scheme = 'https'

def _new_conn(self):
"""
Return a fresh HTTPSConnection.
"""
self.num_connections += 1
log.info("Starting new HTTPS connection (%d): %s" % (self.num_connections, self.host))
return HTTPSConnection(host=self.host, port=self.port)


## Helpers

def get_host(url):
"""
Given a url, return its scheme, host and port (None if it's not there).
For example:
>>> get_host('http://google.com/mail/')
http, google.com, None
>>> get_host('google.com:80')
http, google.com, 80
"""
# This code is actually similar to urlparse.urlsplit, but much
# simplified for our needs.
port = None
scheme = 'http'
if '//' in url:
scheme, url = url.split('://', 1)
if '/' in url:
url, path = url.split('/', 1)
if ':' in url:
url, port = url.split(':', 1)
port = int(port)
return scheme, url, port

def connection_from_url(url, **kw):
"""
Given a url, return an HTTP(S)ConnectionPool instance of its host.
This is a shortcut for not having to determine the host of the url
before creating an HTTP(S)ConnectionPool instance.
Passes on whatever kw arguments to the constructor of
HTTP(S)ConnectionPool. (e.g. timeout, maxsize, block)
"""
scheme, host, port = get_host(url)
if scheme == 'https':
return HTTPSConnectionPool(host, port=port, **kw)
else:
return HTTPConnectionPool(host, port=port, **kw)

0 comments on commit eea79db

Please sign in to comment.