diff --git a/vumi/tests/fake_connection.py b/vumi/tests/fake_connection.py index 24bfd5767..327b83fd1 100644 --- a/vumi/tests/fake_connection.py +++ b/vumi/tests/fake_connection.py @@ -271,7 +271,7 @@ def __init__(self, handler): def endpoint(self): return self.fake_server.endpoint - def get_agent(self, reactor=None, contextFactory=None): + def get_agent(self, reactor=None, pool=None, contextFactory=None): """ Returns an IAgent that makes requests to this fake server. """ @@ -297,6 +297,7 @@ def render_PUT(self, request): class ProxyAgentWithContext(ProxyAgent): - def __init__(self, endpoint, reactor=None, contextFactory=None): + def __init__(self, endpoint, reactor=None, pool=None, contextFactory=None): self.contextFactory = contextFactory # To assert on in tests. - super(ProxyAgentWithContext, self).__init__(endpoint, reactor=reactor) + super(ProxyAgentWithContext, self).__init__( + endpoint, reactor=reactor, pool=pool) diff --git a/vumi/tests/test_utils.py b/vumi/tests/test_utils.py index a70a3ccfc..7478a3e0b 100644 --- a/vumi/tests/test_utils.py +++ b/vumi/tests/test_utils.py @@ -234,9 +234,9 @@ def test_http_request_with_custom_context_factory(self): ctxt = WebClientContextFactory() - def stashing_factory(reactor, contextFactory=None): + def stashing_factory(reactor, contextFactory=None, pool=None): agent = self.fake_http.get_agent( - reactor, contextFactory=contextFactory) + reactor, contextFactory=contextFactory, pool=pool) agents.append(agent) return agent diff --git a/vumi/utils.py b/vumi/utils.py index 0de4c0a25..a0bbb5be1 100644 --- a/vumi/utils.py +++ b/vumi/utils.py @@ -13,12 +13,13 @@ from twisted.internet import protocol from twisted.internet.defer import succeed from twisted.python.failure import Failure -from twisted.web.client import Agent, ResponseDone, WebClientContextFactory +from twisted.web.client import Agent, ResponseDone, HTTPConnectionPool from twisted.web.server import Site from twisted.web.http_headers import Headers from twisted.web.iweb import IBodyProducer from twisted.web.http import PotentialDataLoss from twisted.web.resource import Resource +from treq.client import HTTPClient from vumi.errors import VumiError @@ -119,13 +120,66 @@ def connectionLost(self, reason): def http_request_full(url, data=None, headers={}, method='POST', timeout=None, data_limit=None, context_factory=None, agent_class=None, reactor=None): + """ + This is a drop in replacement for the original `http_request_full` method + but it has its internals completely replaced by treq. Treq supports SNI + and our implementation does not for some reason. Also, we do not want + to continue maintaining this because we're favouring treq everywhere + anyway. + + """ + agent_class = agent_class or Agent + if reactor is None: + # The import replaces the local variable. + from twisted.internet import reactor + kwargs = {'pool': HTTPConnectionPool(reactor, persistent=False)} + if context_factory is not None: + kwargs['contextFactory'] = context_factory + agent = agent_class(reactor, **kwargs) + client = HTTPClient(agent) + + def handle_response(response): + return SimplishReceiver(response, data_limit).deferred + + d = client.request(method, url, headers=headers, data=data) + d.addCallback(handle_response) + + if timeout is not None: + cancelling_on_timeout = [False] + + def raise_timeout(reason): + if not cancelling_on_timeout[0] or reason.check(HttpTimeoutError): + return reason + return Failure(HttpTimeoutError("Timeout while connecting")) + + def cancel_on_timeout(): + cancelling_on_timeout[0] = True + d.cancel() + + def cancel_timeout(r, delayed_call): + if delayed_call.active(): + delayed_call.cancel() + return r + + d.addErrback(raise_timeout) + delayed_call = reactor.callLater(timeout, cancel_on_timeout) + d.addCallback(cancel_timeout, delayed_call) + + return d + + +def old_http_request_full(url, data=None, headers={}, method='POST', + timeout=None, data_limit=None, context_factory=None, + agent_class=None, reactor=None): if reactor is None: # The import replaces the local variable. from twisted.internet import reactor if agent_class is None: agent_class = Agent - context_factory = context_factory or WebClientContextFactory() - agent = agent_class(reactor, contextFactory=context_factory) + kwargs = {} + if context_factory is not None: + kwargs['contextFactory'] = context_factory + agent = agent_class(reactor, **kwargs) d = agent.request(method, url, mkheaders(headers),