Skip to content

Commit

Permalink
Sketch of wrapper-based integration.
Browse files Browse the repository at this point in the history
  • Loading branch information
pythonspeed committed Sep 18, 2023
1 parent d760cde commit a18694b
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 9 deletions.
5 changes: 3 additions & 2 deletions src/twisted/internet/_newtls.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

from twisted.internet.abstract import FileDescriptor
from twisted.internet.interfaces import ISSLTransport
from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol
from twisted.protocols.tls import TLSMemoryBIOFactory


class _BypassTLS:
Expand Down Expand Up @@ -128,7 +128,8 @@ def startTLS(transport, contextFactory, normal, bypass):
transport.unregisterProducer()

tlsFactory = TLSMemoryBIOFactory(contextFactory, client, None)
tlsProtocol = TLSMemoryBIOProtocol(tlsFactory, transport.protocol, False)
tlsProtocol = tlsFactory.protocol(tlsFactory, transport.protocol, False)
# Hook up the new TLS protocol to the transport:
transport.protocol = tlsProtocol

transport.getHandle = tlsProtocol.getHandle
Expand Down
10 changes: 5 additions & 5 deletions src/twisted/internet/test/test_newtls.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,17 +91,17 @@ def __init__(self, producer, result):
self.result = result

def handshakeCompleted(self):
if not isinstance(self.transport.protocol, tls.TLSMemoryBIOProtocol):
if not isinstance(self.transport.protocol, tls.BufferingTLSTransport):
# Either the test or the code have a bug...
raise RuntimeError("TLSMemoryBIOProtocol not hooked up.")

self.transport.registerProducer(self.producer, True)
# The producer was registered with the TLSMemoryBIOProtocol:
self.result.append(self.transport.protocol._producer._producer)
self.result.append(self.transport.protocol._protocol._producer._producer)

self.transport.unregisterProducer()
# The producer was unregistered from the TLSMemoryBIOProtocol:
self.result.append(self.transport.protocol._producer)
self.result.append(self.transport.protocol._protocol._producer)
self.transport.loseConnection()


Expand Down Expand Up @@ -165,11 +165,11 @@ def connectionMade(self):
# status:
if streaming:
# _ProducerMembrane -> producer:
result.append(self.transport.protocol._producer._producer)
result.append(self.transport.protocol._protocol._producer._producer)
result.append(self.transport.producer._producer)
else:
# _ProducerMembrane -> _PullToPush -> producer:
result.append(self.transport.protocol._producer._producer._producer)
result.append(self.transport.protocol._protocol._producer._producer._producer)
result.append(self.transport.producer._producer._producer)
self.transport.unregisterProducer()
self.transport.loseConnection()
Expand Down
72 changes: 70 additions & 2 deletions src/twisted/protocols/tls.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@
transports, such as UNIX sockets and stdio.
"""

from __future__ import annotations

from typing import cast, Callable, Optional, Type

from zope.interface import directlyProvides, implementer, providedBy

Expand All @@ -44,14 +47,19 @@
from twisted.internet._producer_helpers import _PullToPush
from twisted.internet._sslverify import _setAcceptableProtocols
from twisted.internet.interfaces import (
IConsumer,
IHandshakeListener,
ILoggingContext,
INegotiated,
IOpenSSLClientConnectionCreator,
IOpenSSLServerConnectionCreator,
IProtocol,
IProtocolNegotiationFactory,
IPushProducer,
IReactorTime,
ISystemHandle,
ISSLTransport,
ITransport,
)
from twisted.internet.main import CONNECTION_LOST
from twisted.internet.protocol import Protocol
Expand Down Expand Up @@ -116,7 +124,7 @@ def _representsEOF(exceptionObject: Error) -> bool:
return reasonString.casefold().startswith("unexpected eof")


@implementer(ISystemHandle, INegotiated)
@implementer(ISystemHandle, INegotiated, ITransport)
class TLSMemoryBIOProtocol(ProtocolWrapper):
"""
L{TLSMemoryBIOProtocol} is a protocol wrapper which uses OpenSSL via a
Expand All @@ -130,6 +138,9 @@ class TLSMemoryBIOProtocol(ProtocolWrapper):
merged using the L{_ProducerMembrane} wrapper. Non-streaming (pull)
producers are supported by wrapping them with L{_PullToPush}.
Because TLS may need to wait for reads before writing, some writes may be
buffered until a read occurs.
@ivar _tlsConnection: The L{OpenSSL.SSL.Connection} instance which is
encrypted and decrypting this connection.
Expand Down Expand Up @@ -682,6 +693,61 @@ def clientConnectionForTLS(self, protocol):
return self._connectionForTLS(protocol)


class AggregateSmallWrites:
"""
Aggregate small writes so that we don't do expensive small writes into a
``OpenSSL.SSL.Connection`` instance.
"""

def __init__(self, write: Callable[[bytes],object], clock: IReactorTime):
self._clock = clock

# TODO implement buffering logic!


@implementer(ITransport, INegotiated, ISSLTransport, IProtocol, ISystemHandle, IConsumer, ILoggingContext)
class BufferingTLSTransport(AggregateSmallWrites):
"""
A TLS transport implemented by wrapping buffering around a
``TLSMemoryBIOProtocol``.
Doing many small writes directly to a ``OpenSSL.SSL.Connection``, as
implemented in ``TLSMemoryBIOProtocol``, can add significant CPU and
bandwidth overhead. Thus, even when writing is possible, small writes will
get aggregated and written as a single write at the next reactor iteration.
"""

def __init__(
self,
factory: TLSMemoryBIOFactory,
wrappedProtocol: IProtocol,
_connectWrapped: bool = True,
clock: Optional[IReactorTime] = None
):
if clock is None:
from twisted.internet import reactor
clock = cast(IReactorTime, reactor)
self._clock = clock
self._protocol = TLSMemoryBIOProtocol(factory, wrappedProtocol, _connectWrapped)

# Attributes we will forward to the wrapped protocol; "wrappedProtocol"
# is used in twisted/internet/endpoints.py, which is an abstraction
# violation...
self._forwarding_names : set[str] = {"wrappedProtocol"}
for interface in providedBy(self):
self._forwarding_names |= interface.names()

AggregateSmallWrites.__init__(self, self._protocol.write, clock)

# TODO hook up AggregateSmallWrites to def write()/writeSequence()

def __getattr__(self, attr):
if attr in self._forwarding_names:
return getattr(self._protocol, attr)
else:
raise AttributeError("Unknown attribute", attr)


class TLSMemoryBIOFactory(WrappingFactory):
"""
L{TLSMemoryBIOFactory} adds TLS to connections.
Expand All @@ -696,7 +762,9 @@ class TLSMemoryBIOFactory(WrappingFactory):
L{TLSMemoryBIOProtocol} and returning L{OpenSSL.SSL.Connection}.
"""

protocol = TLSMemoryBIOProtocol
# BufferingTLSTransport wraps TLSMemoryBIOProtocol, which is a
# ProtocolWrapper.
protocol = cast(Type[ProtocolWrapper], BufferingTLSTransport)

noisy = False # disable unnecessary logging.

Expand Down

0 comments on commit a18694b

Please sign in to comment.