From 0771f8c8fdbe1ca976e1c94a20f8a092f2f9eaf0 Mon Sep 17 00:00:00 2001 From: "Vladlen Y. Koshelev" Date: Wed, 13 Mar 2013 19:21:23 +0400 Subject: [PATCH] Add hooks for HTTPClient and HTTPConnection customization. --- tornado/simple_httpclient.py | 54 +++++++++++++++++++++--------------- 1 file changed, 31 insertions(+), 23 deletions(-) diff --git a/tornado/simple_httpclient.py b/tornado/simple_httpclient.py index ed344e7291..9b6fb6cf14 100644 --- a/tornado/simple_httpclient.py +++ b/tornado/simple_httpclient.py @@ -92,10 +92,12 @@ def _process_queue(self): request, callback = self.queue.popleft() key = object() self.active[key] = (request, callback) - _HTTPConnection(self.io_loop, self, request, - functools.partial(self._release_fetch, key), - callback, - self.max_buffer_size, self.resolver) + release_callback = functools.partial(self._release_fetch, key) + self._handle_request(request, release_callback, callback) + + def _handle_request(self, request, release_callback, final_callback): + _HTTPConnection(self.io_loop, self, request, release_callback, + final_callback, self.max_buffer_size, self.resolver) def _release_fetch(self, key): del self.active[key] @@ -153,8 +155,21 @@ def __init__(self, io_loop, client, request, release_callback, self.resolver.resolve(host, port, af, callback=self._on_resolve) def _on_resolve(self, addrinfo): - af, sockaddr = addrinfo[0] + self.stream = self._create_stream(addrinfo) + timeout = min(self.request.connect_timeout, self.request.request_timeout) + if timeout: + self._timeout = self.io_loop.add_timeout( + self.start_time + timeout, + stack_context.wrap(self._on_timeout)) + self.stream.set_close_callback(self._on_close) + # ipv6 addresses are broken (in self.parsed.hostname) until + # 2.7, here is correctly parsed value calculated in __init__ + sockaddr = addrinfo[0][1] + self.stream.connect(sockaddr, self._on_connect, + server_hostname=self.parsed_hostname) + def _create_stream(self, addrinfo): + af = addrinfo[0][0] if self.parsed.scheme == "https": ssl_options = {} if self.request.validate_cert: @@ -187,24 +202,14 @@ def _on_resolve(self, addrinfo): # information. ssl_options["ssl_version"] = ssl.PROTOCOL_SSLv3 - self.stream = SSLIOStream(socket.socket(af), - io_loop=self.io_loop, - ssl_options=ssl_options, - max_buffer_size=self.max_buffer_size) + return SSLIOStream(socket.socket(af), + io_loop=self.io_loop, + ssl_options=ssl_options, + max_buffer_size=self.max_buffer_size) else: - self.stream = IOStream(socket.socket(af), - io_loop=self.io_loop, - max_buffer_size=self.max_buffer_size) - timeout = min(self.request.connect_timeout, self.request.request_timeout) - if timeout: - self._timeout = self.io_loop.add_timeout( - self.start_time + timeout, - stack_context.wrap(self._on_timeout)) - self.stream.set_close_callback(self._on_close) - # ipv6 addresses are broken (in self.parsed.hostname) until - # 2.7, here is correctly parsed value calculated in __init__ - self.stream.connect(sockaddr, self._on_connect, - server_hostname=self.parsed_hostname) + return IOStream(socket.socket(af), + io_loop=self.io_loop, + max_buffer_size=self.max_buffer_size) def _on_timeout(self): self._timeout = None @@ -412,7 +417,7 @@ def _on_body(self, data): self.final_callback = None self._release() self.client.fetch(new_request, final_callback) - self.stream.close() + self._on_end_request() return if self._decompressor: data = (self._decompressor.decompress(data) + @@ -432,6 +437,9 @@ def _on_body(self, data): buffer=buffer, effective_url=self.request.url) self._run_callback(response) + self._on_end_request() + + def _on_end_request(self): self.stream.close() def _on_chunk_length(self, data):