Skip to content

Commit

Permalink
Switch to inheritance-based implementation since that is likely to cause
Browse files Browse the repository at this point in the history
breaking changes.

The composition-based version broke Tahoe, so this is not a theoretical concern.
  • Loading branch information
pythonspeed committed Sep 18, 2023
1 parent c4a41a4 commit c75b9ab
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 28 deletions.
8 changes: 4 additions & 4 deletions src/twisted/internet/test/test_newtls.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,11 +97,11 @@ def handshakeCompleted(self):

self.transport.registerProducer(self.producer, True)
# The producer was registered with the TLSMemoryBIOProtocol:
self.result.append(self.transport.protocol._protocol._producer._producer)
self.result.append(self.transport.protocol._producer._producer)

self.transport.unregisterProducer()
# The producer was unregistered from the TLSMemoryBIOProtocol:
self.result.append(self.transport.protocol._protocol._producer)
self.result.append(self.transport.protocol._producer)
self.transport.loseConnection()


Expand Down Expand Up @@ -165,11 +165,11 @@ def connectionMade(self):
# status:
if streaming:
# _ProducerMembrane -> producer:
result.append(self.transport.protocol._protocol._producer._producer)
result.append(self.transport.protocol._producer._producer)
result.append(self.transport.producer._producer)
else:
# _ProducerMembrane -> _PullToPush -> producer:
result.append(self.transport.protocol._protocol._producer._producer._producer)
result.append(self.transport.protocol._producer._producer._producer)
result.append(self.transport.producer._producer._producer)
self.transport.unregisterProducer()
self.transport.loseConnection()
Expand Down
37 changes: 13 additions & 24 deletions src/twisted/protocols/tls.py
Original file line number Diff line number Diff line change
Expand Up @@ -694,7 +694,7 @@ def clientConnectionForTLS(self, protocol):
return self._connectionForTLS(protocol)


class AggregateSmallWrites:
class AggregateSmallWrites: # TODO make this private for now?
"""
Aggregate small writes so they get written in large batches.
Expand Down Expand Up @@ -737,8 +737,7 @@ def flush(self) -> None:
del self._buffer[:]


@implementer(ITransport, INegotiated, ISSLTransport, IProtocol, ISystemHandle, IConsumer, ILoggingContext)
class BufferingTLSTransport:
class BufferingTLSTransport(TLSMemoryBIOProtocol):
"""
A TLS transport implemented by wrapping buffering around a
``TLSMemoryBIOProtocol``.
Expand All @@ -749,28 +748,26 @@ class BufferingTLSTransport:
get aggregated and written as a single write at the next reactor iteration.
"""

# Note: An implementation based on composition would be nicer, but there's
# close integration between ``ProtocolWrapper`` subclasses like
# ``TLSMemoryBIOProtocol`` and the corresponding factory. Composition broke
# things like ``TLSMemoryBIOFactory.protocols`` having the correct
# instances, whereas subclassing makes that work.

def __init__(
self,
factory: TLSMemoryBIOFactory,
wrappedProtocol: IProtocol,
_connectWrapped: bool = True,
clock: Optional[IReactorTime] = None
):
self._protocol = TLSMemoryBIOProtocol(factory, wrappedProtocol, _connectWrapped)

# Attributes we will forward to the wrapped protocol; "wrappedProtocol"
# is used in twisted/internet/endpoints.py, which is an abstraction
# violation...
self._forwarding_names : set[str] = {"wrappedProtocol"}
# TODO fix twisted.protocols.test.test_tls so we don't need these?
# {"transport", "_producer", "_tlsConnection", "_shutdownTLS"}
for interface in providedBy(self):
self._forwarding_names |= interface.names()
super().__init__(factory, wrappedProtocol, _connectWrapped)

if clock is None:
from twisted.internet import reactor
clock = cast(IReactorTime, reactor)
self._aggregator = AggregateSmallWrites(self._protocol.write, clock)
actual_write = super().write
self._aggregator = AggregateSmallWrites(actual_write, clock)

def write(self, data):
self._aggregator.write(data)
Expand All @@ -780,13 +777,7 @@ def writeSequence(self, sequence):

def loseConnection(self):
self._aggregator.flush()
self._protocol.loseConnection()

def __getattr__(self, attr):
if attr in self._forwarding_names:
return getattr(self._protocol, attr)
else:
raise AttributeError("Unknown attribute", attr)
super().loseConnection()


class TLSMemoryBIOFactory(WrappingFactory):
Expand All @@ -803,9 +794,7 @@ class TLSMemoryBIOFactory(WrappingFactory):
L{TLSMemoryBIOProtocol} and returning L{OpenSSL.SSL.Connection}.
"""

# BufferingTLSTransport wraps TLSMemoryBIOProtocol, which is a
# ProtocolWrapper.
protocol = cast(Type[ProtocolWrapper], BufferingTLSTransport)
protocol = BufferingTLSTransport

noisy = False # disable unnecessary logging.

Expand Down

0 comments on commit c75b9ab

Please sign in to comment.