From 2d39ccad9d8a9b86229dd7fb1ce47ca841b730be Mon Sep 17 00:00:00 2001 From: "Amber Brown (HawkOwl)" Date: Sun, 15 May 2016 17:53:13 +0800 Subject: [PATCH 01/10] apply patch from lukasa, refs #8320 --- twisted/web/http.py | 154 +++++++++++++++++++++------------- twisted/web/iweb.py | 19 +++++ twisted/web/proxy.py | 6 +- twisted/web/server.py | 2 +- twisted/web/test/test_http.py | 133 ----------------------------- twisted/web/test/test_web.py | 30 ++++--- 6 files changed, 137 insertions(+), 207 deletions(-) diff --git a/twisted/web/http.py b/twisted/web/http.py index 2203cd36129..0bf06cec0eb 100644 --- a/twisted/web/http.py +++ b/twisted/web/http.py @@ -100,7 +100,8 @@ def _parseHeader(line): from twisted.internet.interfaces import IProtocol from twisted.protocols import policies, basic -from twisted.web.iweb import IRequest, IAccessLogFormatter +from twisted.web.iweb import ( + IRequest, IAccessLogFormatter, INonQueuedRequestFactory) from twisted.web.http_headers import Headers H2_ENABLED = False @@ -534,6 +535,10 @@ def rawDataReceived(self, data): NO_BODY_CODES = (204, 304) +# Sentinel object that detects people explicitly passing `queued` to Request. +_QUEUED_SENTINEL = object() + + @implementer(interfaces.IConsumer) class Request: """ @@ -587,11 +592,11 @@ class Request: _forceSSL = 0 _disconnected = False - def __init__(self, channel, queued): + def __init__(self, channel, queued=_QUEUED_SENTINEL): """ @param channel: the channel we're connected to. - @param queued: are we in the request queue, or can we start writing to - the transport? + @param queued: (deprecated) are we in the request queue, or can we + start writing to the transport? """ self.notifications = [] self.channel = channel @@ -600,11 +605,7 @@ def __init__(self, channel, queued): self.received_cookies = {} self.responseHeaders = Headers() self.cookies = [] # outgoing cookies - - if queued: - self.transport = StringTransport() - else: - self.transport = self.channel.transport + self.transport = self.channel.transport def _cleanup(self): @@ -616,45 +617,19 @@ def _cleanup(self): self.unregisterProducer() self.channel.requestDone(self) del self.channel - try: - self.content.close() - except OSError: - # win32 suckiness, no idea why it does this - pass - del self.content + if self.content is not None: + try: + self.content.close() + except OSError: + # win32 suckiness, no idea why it does this + pass + del self.content for d in self.notifications: d.callback(None) self.notifications = [] # methods for channel - end users should not use these - def noLongerQueued(self): - """ - Notify the object that it is no longer queued. - - We start writing whatever data we have to the transport, etc. - - This method is not intended for users. - """ - if not self.queued: - raise RuntimeError("noLongerQueued() got called unnecessarily.") - - self.queued = 0 - - # set transport to real one and send any buffer data - data = self.transport.getvalue() - self.transport = self.channel.transport - if data: - self.transport.write(data) - - # if we have producer, register it with transport - if (self.producer is not None) and not self.finished: - self.transport.registerProducer(self.producer, self.streamingProducer) - - # if we're finished, clean up - if self.finished: - self._cleanup() - def gotLength(self, length): """ Called when HTTP channel got length of content in this request. @@ -806,19 +781,13 @@ def registerProducer(self, producer, streaming): self.streamingProducer = streaming self.producer = producer - - if self.queued: - if streaming: - producer.pauseProducing() - else: - self.transport.registerProducer(producer, streaming) + self.transport.registerProducer(producer, streaming) def unregisterProducer(self): """ Unregister the producer. """ - if not self.queued: - self.transport.unregisterProducer() + self.transport.unregisterProducer() self.producer = None @@ -1264,11 +1233,12 @@ def isSecure(self): """ if self._forceSSL: return True - transport = getattr(getattr(self, 'channel', None), 'transport', None) + transport = getattr(self, 'transport', None) if interfaces.ISSLTransport(transport, None) is not None: return True return False + def _authorize(self): # Authorization, (mostly) per the RFC try: @@ -1585,6 +1555,34 @@ def noMoreData(self): +@implementer(interfaces.IPushProducer) +class _NoPushProducer(object): + """ + A no-op version of L{interfaces.IPushProducer}, used to abstract over the + possibility that a L{HTTPChannel} transport does not provide + L{IPushProducer}. + """ + def pauseProducing(self): + """ + Pause producing data. + + Tells a producer that it has produced too much data to process for + the time being, and to stop until resumeProducing() is called. + """ + pass + + + def resumeProducing(self): + """ + Resume producing data. + + This tells a producer to re-add itself to the main loop and produce + more data for its consumer. + """ + pass + + + class HTTPChannel(basic.LineReceiver, policies.TimeoutMixin): """ A receiver for HTTP requests. @@ -1605,6 +1603,19 @@ class HTTPChannel(basic.LineReceiver, policies.TimeoutMixin): @ivar _receivedHeaderSize: Bytes received so far for the header. @type _receivedHeaderSize: C{int} + + @ivar _handlingRequest: Whether a request is currently being processed. + @type _handlingRequest: L{bool} + + @ivar _dataBuffer: Any data that has been received from the connection + while processing an outstanding request. + @type _dataBuffer: L{list} of L{bytes} + + @ivar _producer: Either the transport, if it provides + L{interfaces.IPushProducer}, or a null implementation of + L{interfaces.IPushProducer}. Used to attempt to prevent the transport + from producing excess data when we're responding to a request. + @type _producer: L{interfaces.IPushProducer} """ maxHeaders = 500 @@ -1626,11 +1637,16 @@ class HTTPChannel(basic.LineReceiver, policies.TimeoutMixin): def __init__(self): # the request queue self.requests = [] + self._handlingRequest = False + self._dataBuffer = [] self._transferDecoder = None def connectionMade(self): self.setTimeout(self.timeOut) + self._producer = interfaces.IPushProducer( + self.transport, _NoPushProducer() + ) def lineReceived(self, line): @@ -1659,7 +1675,10 @@ def lineReceived(self, line): return # create a new Request object - request = self.requestFactory(self, len(self.requests)) + if INonQueuedRequestFactory.providedBy(self.requestFactory): + request = self.requestFactory(self) + else: + request = self.requestFactory(self, len(self.requests)) self.requests.append(request) self.__first_line = 0 @@ -1706,7 +1725,7 @@ def lineReceived(self, line): def _finishRequestBody(self, data): self.allContentReceived() - self.setLineMode(data) + self._dataBuffer.append(data) def headerReceived(self, line): @@ -1777,12 +1796,23 @@ def allContentReceived(self): if self.timeOut: self._savedTimeOut = self.setTimeout(None) + # Pause the producer if we can. If we can't, that's ok, we'll buffer. + self._producer.pauseProducing() + self._handlingRequest = True + req = self.requests[-1] req.requestReceived(command, path, version) def rawDataReceived(self, data): self.resetTimeout() + + # If we're currently handling a request, buffer this data. We shouldn't + # have received it (we've paused the transport), but let's be cautious. + if self._handlingRequest: + self._dataBuffer.append(data) + return + try: self._transferDecoder.dataReceived(data) except _MalformedChunkedDataError: @@ -1854,19 +1884,25 @@ def requestDone(self, request): del self.requests[0] if self.persistent: - # notify next request it can start writing - if self.requests: - self.requests[0].noLongerQueued() - else: - if self._savedTimeOut: - self.setTimeout(self._savedTimeOut) + self._handlingRequest = False + self._producer.resumeProducing() + + if self._savedTimeOut: + self.setTimeout(self._savedTimeOut) + + # Receive our buffered data, if any. + data = b''.join(self._dataBuffer) + self._dataBuffer = [] + self.setLineMode(data) else: self.transport.loseConnection() + def timeoutConnection(self): log.msg("Timing out client: %s" % str(self.transport.getPeer())) policies.TimeoutMixin.timeoutConnection(self) + def connectionLost(self, reason): self.setTimeout(None) for request in self.requests: diff --git a/twisted/web/iweb.py b/twisted/web/iweb.py index ce26f8c6f36..4805fb49d71 100644 --- a/twisted/web/iweb.py +++ b/twisted/web/iweb.py @@ -311,6 +311,25 @@ def setHost(host, port, ssl=0): +class INonQueuedRequestFactory(Interface): + """ + A factory of L{IRequest} objects that does not take a ``queued`` parameter. + """ + def __call__(channel): + """ + Create an L{IRequest} that is operating on the given channel. There + must only be one L{IRequest} object processing at any given time on a + channel. + + @param channel: A L{twisted.web.http.HTTPChannel} object. + @type channel: L{twisted.web.http.HTTPChannel} + + @return: A request object. + @rtype: L{IRequest} + """ + + + class IAccessLogFormatter(Interface): """ An object which can represent an HTTP request as a line of text for diff --git a/twisted/web/proxy.py b/twisted/web/proxy.py index 9f88a2e78db..95f70d6b84e 100644 --- a/twisted/web/proxy.py +++ b/twisted/web/proxy.py @@ -25,7 +25,7 @@ from twisted.internet.protocol import ClientFactory from twisted.web.resource import Resource from twisted.web.server import NOT_DONE_YET -from twisted.web.http import HTTPClient, Request, HTTPChannel +from twisted.web.http import HTTPClient, Request, HTTPChannel, _QUEUED_SENTINEL @@ -134,7 +134,7 @@ class ProxyRequest(Request): protocols = {b'http': ProxyClientFactory} ports = {b'http': 80} - def __init__(self, channel, queued, reactor=reactor): + def __init__(self, channel, queued=_QUEUED_SENTINEL, reactor=reactor): Request.__init__(self, channel, queued) self.reactor = reactor @@ -195,7 +195,7 @@ class ReverseProxyRequest(Request): proxyClientFactoryClass = ProxyClientFactory - def __init__(self, channel, queued, reactor=reactor): + def __init__(self, channel, queued=_QUEUED_SENTINEL, reactor=reactor): Request.__init__(self, channel, queued) self.reactor = reactor diff --git a/twisted/web/server.py b/twisted/web/server.py index 3431dc8c986..a9ab407b8c9 100644 --- a/twisted/web/server.py +++ b/twisted/web/server.py @@ -620,7 +620,7 @@ class Site(http.HTTPFactory): A web site: manage log, sessions, and resources. @ivar counter: increment value used for generating unique sessions ID. - @ivar requestFactory: A factory which is called with (channel, queued) + @ivar requestFactory: A factory which is called with (channel) and creates L{Request} instances. Default to L{Request}. @ivar displayTracebacks: if set, Twisted internal errors are displayed on rendered pages. Default to C{True}. diff --git a/twisted/web/test/test_http.py b/twisted/web/test/test_http.py index 76910772000..57a9fd0d5b1 100644 --- a/twisted/web/test/test_http.py +++ b/twisted/web/test/test_http.py @@ -2035,67 +2035,6 @@ def test_registerProducerTwiceFails(self): ValueError, req.registerProducer, DummyProducer(), True) - def test_registerProducerWhenQueuedPausesPushProducer(self): - """ - Calling L{Request.registerProducer} with an IPushProducer when the - request is queued pauses the producer. - """ - req = http.Request(DummyChannel(), True) - producer = DummyProducer() - req.registerProducer(producer, True) - self.assertEqual(['pause'], producer.events) - - - def test_registerProducerWhenQueuedDoesntPausePullProducer(self): - """ - Calling L{Request.registerProducer} with an IPullProducer when the - request is queued does not pause the producer, because it doesn't make - sense to pause a pull producer. - """ - req = http.Request(DummyChannel(), True) - producer = DummyProducer() - req.registerProducer(producer, False) - self.assertEqual([], producer.events) - - - def test_registerProducerWhenQueuedDoesntRegisterPushProducer(self): - """ - Calling L{Request.registerProducer} with an IPushProducer when the - request is queued does not register the producer on the request's - transport. - """ - self.assertIdentical( - None, getattr(http.StringTransport, 'registerProducer', None), - "StringTransport cannot implement registerProducer for this test " - "to be valid.") - req = http.Request(DummyChannel(), True) - producer = DummyProducer() - req.registerProducer(producer, True) - # This is a roundabout assertion: http.StringTransport doesn't - # implement registerProducer, so Request.registerProducer can't have - # tried to call registerProducer on the transport. - self.assertIsInstance(req.transport, http.StringTransport) - - - def test_registerProducerWhenQueuedDoesntRegisterPullProducer(self): - """ - Calling L{Request.registerProducer} with an IPullProducer when the - request is queued does not register the producer on the request's - transport. - """ - self.assertIdentical( - None, getattr(http.StringTransport, 'registerProducer', None), - "StringTransport cannot implement registerProducer for this test " - "to be valid.") - req = http.Request(DummyChannel(), True) - producer = DummyProducer() - req.registerProducer(producer, False) - # This is a roundabout assertion: http.StringTransport doesn't - # implement registerProducer, so Request.registerProducer can't have - # tried to call registerProducer on the transport. - self.assertIsInstance(req.transport, http.StringTransport) - - def test_registerProducerWhenNotQueuedRegistersPushProducer(self): """ Calling L{Request.registerProducer} with an IPushProducer when the @@ -2243,38 +2182,6 @@ def test_unregisterNonQueuedStreamingProducer(self): self.assertEqual((None, None), (req.producer, req.transport.producer)) - def test_unregisterQueuedNonStreamingProducer(self): - """ - L{Request.unregisterProducer} unregisters a queued non-streaming - producer from the request but not from the transport. - """ - existing = DummyProducer() - channel = DummyChannel() - transport = StringTransport() - channel.transport = transport - transport.registerProducer(existing, True) - req = http.Request(channel, True) - req.registerProducer(DummyProducer(), False) - req.unregisterProducer() - self.assertEqual((None, existing), (req.producer, transport.producer)) - - - def test_unregisterQueuedStreamingProducer(self): - """ - L{Request.unregisterProducer} unregisters a queued streaming producer - from the request but not from the transport. - """ - existing = DummyProducer() - channel = DummyChannel() - transport = StringTransport() - channel.transport = transport - transport.registerProducer(existing, True) - req = http.Request(channel, True) - req.registerProducer(DummyProducer(), True) - req.unregisterProducer() - self.assertEqual((None, existing), (req.producer, transport.producer)) - - def test_finishProducesLog(self): """ L{http.Request.finish} will call the channel's factory to produce a log @@ -2491,46 +2398,6 @@ def test_expect100ContinueHeader(self): b"'''\n3\nabc'''\n")]) - def test_expect100ContinueWithPipelining(self): - """ - If a HTTP/1.1 client sends a 'Expect: 100-continue' header, followed - by another pipelined request, the 100 response does not interfere with - the response to the second request. - """ - transport = StringTransport() - channel = http.HTTPChannel() - channel.requestFactory = DummyHTTPHandler - channel.makeConnection(transport) - channel.dataReceived( - b"GET / HTTP/1.1\r\n" - b"Host: www.example.com\r\n" - b"Expect: 100-continue\r\n" - b"Content-Length: 3\r\n" - b"\r\nabc" - b"POST /foo HTTP/1.1\r\n" - b"Host: www.example.com\r\n" - b"Content-Length: 4\r\n" - b"\r\ndefg") - response = transport.value() - self.assertTrue( - response.startswith(b"HTTP/1.1 100 Continue\r\n\r\n")) - response = response[len(b"HTTP/1.1 100 Continue\r\n\r\n"):] - self.assertResponseEquals( - response, - [(b"HTTP/1.1 200 OK", - b"Command: GET", - b"Content-Length: 13", - b"Version: HTTP/1.1", - b"Request: /", - b"'''\n3\nabc'''\n"), - (b"HTTP/1.1 200 OK", - b"Command: POST", - b"Content-Length: 14", - b"Version: HTTP/1.1", - b"Request: /foo", - b"'''\n4\ndefg'''\n")]) - - def sub(keys, d): """ diff --git a/twisted/web/test/test_web.py b/twisted/web/test/test_web.py index 23eb86dc646..36e1ebf5f25 100644 --- a/twisted/web/test/test_web.py +++ b/twisted/web/test/test_web.py @@ -554,8 +554,10 @@ def test_processingFailedNoTraceback(self): fail = failure.Failure(Exception("Oh no!")) request.processingFailed(fail) - self.assertNotIn(b"Oh no!", request.transport.getvalue()) - self.assertIn(b"Processing Failed", request.transport.getvalue()) + self.assertNotIn(b"Oh no!", request.transport.written.getvalue()) + self.assertIn( + b"Processing Failed", request.transport.written.getvalue() + ) # Since we didn't "handle" the exception, flush it to prevent a test # failure @@ -574,7 +576,7 @@ def test_processingFailedDisplayTraceback(self): fail = failure.Failure(Exception("Oh no!")) request.processingFailed(fail) - self.assertIn(b"Oh no!", request.transport.getvalue()) + self.assertIn(b"Oh no!", request.transport.written.getvalue()) # Since we didn't "handle" the exception, flush it to prevent a test # failure @@ -594,7 +596,7 @@ def test_processingFailedDisplayTracebackHandlesUnicode(self): fail = failure.Failure(Exception(u"\u2603")) request.processingFailed(fail) - self.assertIn(b"☃", request.transport.getvalue()) + self.assertIn(b"☃", request.transport.written.getvalue()) # Since we didn't "handle" the exception, flush it to prevent a test # failure @@ -823,11 +825,15 @@ def _getReq(self, resource=None): def testGoodMethods(self): req = self._getReq() req.requestReceived(b'GET', b'/newrender', b'HTTP/1.0') - self.assertEqual(req.transport.getvalue().splitlines()[-1], b'hi hi') + self.assertEqual( + req.transport.written.getvalue().splitlines()[-1], b'hi hi' + ) req = self._getReq() req.requestReceived(b'HEH', b'/newrender', b'HTTP/1.0') - self.assertEqual(req.transport.getvalue().splitlines()[-1], b'ho ho') + self.assertEqual( + req.transport.written.getvalue().splitlines()[-1], b'ho ho' + ) def testBadMethods(self): req = self._getReq() @@ -855,7 +861,9 @@ def testImplicitHead(self): req = self._getReq() req.requestReceived(b'HEAD', b'/newrender', b'HTTP/1.0') self.assertEqual(req.code, 200) - self.assertEqual(-1, req.transport.getvalue().find(b'hi hi')) + self.assertEqual( + -1, req.transport.written.getvalue().find(b'hi hi') + ) def test_unsupportedHead(self): @@ -866,7 +874,7 @@ def test_unsupportedHead(self): resource = HeadlessResource() req = self._getReq(resource) req.requestReceived(b"HEAD", b"/newrender", b"HTTP/1.0") - headers, body = req.transport.getvalue().split(b'\r\n\r\n') + headers, body = req.transport.written.getvalue().split(b'\r\n\r\n') self.assertEqual(req.code, 200) self.assertEqual(body, b'') @@ -887,7 +895,7 @@ def __repr__(self): request.requestReceived(b"GET", b"/newrender", b"HTTP/1.0") - headers, body = request.transport.getvalue().split(b'\r\n\r\n') + headers, body = request.transport.written.getvalue().split(b'\r\n\r\n') self.assertEqual(request.code, 500) expected = [ '', @@ -988,7 +996,7 @@ def test_notAllowedQuoting(self): req.requestReceived(b'POST', b'/gettableresource?' b'value=