Skip to content

Commit

Permalink
Hook up minimal buffering layer.
Browse files Browse the repository at this point in the history
  • Loading branch information
pythonspeed committed Sep 18, 2023
1 parent a18694b commit ef3e07f
Showing 1 changed file with 37 additions and 10 deletions.
47 changes: 37 additions & 10 deletions src/twisted/protocols/tls.py
Original file line number Diff line number Diff line change
Expand Up @@ -695,18 +695,36 @@ def clientConnectionForTLS(self, protocol):

class AggregateSmallWrites:
"""
Aggregate small writes so that we don't do expensive small writes into a
``OpenSSL.SSL.Connection`` instance.
Aggregate small writes so they get written in large batches.
If this is used as part of a transport, the transport needs to call
``flush()`` immediately when ``loseConnection()`` is called, otherwise any
buffered writes will never get written.
"""

def __init__(self, write: Callable[[bytes],object], clock: IReactorTime):
self._write = write
self._clock = clock
self._buffer = []

def write(self, data: bytes) -> None:
"""Buffer the data or write it immediately."""
was_empty = len(self._buffer) == 0
# TODO might want logic that flushes large writes immediately to reduce
# memory usage...
self._buffer.append(data)
if was_empty:
self._clock.callLater(0, self.flush)

# TODO implement buffering logic!
def flush(self) -> None:
"""Flush any buffered writes."""
if self._buffer:
self._write(b"".join(self._buffer))
del self._buffer[:]


@implementer(ITransport, INegotiated, ISSLTransport, IProtocol, ISystemHandle, IConsumer, ILoggingContext)
class BufferingTLSTransport(AggregateSmallWrites):
class BufferingTLSTransport:
"""
A TLS transport implemented by wrapping buffering around a
``TLSMemoryBIOProtocol``.
Expand All @@ -724,22 +742,31 @@ def __init__(
_connectWrapped: bool = True,
clock: Optional[IReactorTime] = None
):
if clock is None:
from twisted.internet import reactor
clock = cast(IReactorTime, reactor)
self._clock = clock
self._protocol = TLSMemoryBIOProtocol(factory, wrappedProtocol, _connectWrapped)

# Attributes we will forward to the wrapped protocol; "wrappedProtocol"
# is used in twisted/internet/endpoints.py, which is an abstraction
# violation...
self._forwarding_names : set[str] = {"wrappedProtocol"}
# TODO fix twisted.protocols.test.test_tls so we don't need these?
# {"transport", "_producer", "_tlsConnection", "_shutdownTLS"}
for interface in providedBy(self):
self._forwarding_names |= interface.names()

AggregateSmallWrites.__init__(self, self._protocol.write, clock)
if clock is None:
from twisted.internet import reactor
clock = cast(IReactorTime, reactor)
self._aggregator = AggregateSmallWrites(self._protocol.write, clock)

# TODO hook up AggregateSmallWrites to def write()/writeSequence()
def write(self, data):
self._aggregator.write(data)

def writeSequence(self, sequence):
self._aggregator.write(b"".join(sequence))

def loseConnection(self):
self._aggregator.flush()
self._protocol.loseConnection()

def __getattr__(self, attr):
if attr in self._forwarding_names:
Expand Down

0 comments on commit ef3e07f

Please sign in to comment.