Skip to content
Permalink
Browse files

Merge commit 'dilate-half-close'

Add half-close to dilation API. Fix several bugs exposed by application-level
testing.

refs #344 but doesn't close it yet: I want to call `p.connectionLost()` after
both directions have been closed down
  • Loading branch information...
warner committed Aug 12, 2019
2 parents b233763 + 854f0d6 commit ab5fe65c3beee5d8cc0d3334beaf0bbac3418c32
@@ -205,8 +205,8 @@ def set_code(self, code):
self._did_start_code = True
self._C.set_code(code)

def dilate(self, no_listen=False):
return self._D.dilate(no_listen=no_listen) # fires with endpoints
def dilate(self, transit_relay_location=None, no_listen=False):
return self._D.dilate(transit_relay_location, no_listen=no_listen) # fires with endpoints

@m.input()
def send(self, plaintext):
@@ -10,7 +10,7 @@
from twisted.internet.endpoints import serverFromString
from twisted.internet.protocol import ClientFactory, ServerFactory
from twisted.internet.address import HostnameAddress, IPv4Address, IPv6Address
from twisted.internet.error import ConnectingCancelledError
from twisted.internet.error import ConnectingCancelledError, ConnectionRefusedError, DNSLookupError
from twisted.python import log
from .. import ipaddrs # TODO: move into _dilation/
from .._interfaces import IDilationConnector, IDilationManager
@@ -28,7 +28,9 @@

def build_sided_relay_handshake(key, side):
assert isinstance(side, type(u""))
assert len(side) == 8 * 2
# magic-wormhole-transit-relay expects a specific layout for the
# handshake message: "please relay {64} for side {16}\n"
assert len(side) == 8 * 2, side
token = HKDF(key, 32, CTXinfo=b"transit_relay_token")
return (b"please relay " + hexlify(token) +
b" for side " + side.encode("ascii") + b"\n")
@@ -310,7 +312,13 @@ def _schedule_connection(self, delay, h, is_relay):
d = deferLater(self._reactor, delay,
self._connect, ep, desc, is_relay)
d.addErrback(lambda f: f.trap(ConnectingCancelledError,
CancelledError))
ConnectionRefusedError,
CancelledError,
))
# TODO: HostnameEndpoint.connect catches CancelledError and replaces
# it with DNSLookupError. Remove this workaround when
# https://twistedmatrix.com/trac/ticket/9696 is fixed.
d.addErrback(lambda f: f.trap(DNSLookupError))
d.addErrback(log.err)
self._pending_connectors.add(d)

@@ -65,7 +65,7 @@ def __getitem__(self, n):
return (self.control, self.connect, self.listen)[n]

def make_side():
return bytes_to_hexstr(os.urandom(6))
return bytes_to_hexstr(os.urandom(8))


# new scheme:
@@ -552,6 +552,7 @@ def notify_stopped(self):
ABANDONING.upon(rx_HINTS, enter=ABANDONING, outputs=[]) # shouldn't happen
STOPPING.upon(rx_HINTS, enter=STOPPING, outputs=[])

WAITING.upon(stop, enter=STOPPED, outputs=[notify_stopped])
WANTING.upon(stop, enter=STOPPED, outputs=[notify_stopped])
CONNECTING.upon(stop, enter=STOPPED, outputs=[stop_connecting, notify_stopped])
CONNECTED.upon(stop, enter=STOPPING, outputs=[abandon_connection])
@@ -6,6 +6,7 @@
from twisted.internet.defer import inlineCallbacks, returnValue
from twisted.internet.interfaces import (ITransport, IProducer, IConsumer,
IAddress, IListeningPort,
IHalfCloseableProtocol,
IStreamClientEndpoint,
IStreamServerEndpoint)
from twisted.internet.error import ConnectionDone
@@ -55,6 +56,11 @@ class SingleUseEndpointError(Exception):
class AlreadyClosedError(Exception):
pass

class NormalCloseUsedOnHalfCloseable(Exception):
pass
class HalfCloseUsedOnNonHalfCloseable(Exception):
pass


@implementer(IAddress)
class _WormholeAddress(object):
@@ -87,11 +93,29 @@ def __attrs_post_init__(self):
# self._pending_outbound = {}
# self._processed = set()
self._protocol = None
self._pending_dataReceived = []
self._pending_connectionLost = (False, None)
self._pending_remote_data = []
self._pending_remote_close = False

@m.state(initial=True)
def open(self):
def unconnected(self):
pass # pragma: no cover

# once we get the IProtocol, it's either a IHalfCloseableProtocol, or it
# can only be fully closed
@m.state()
def open_half(self):
pass # pragma: no cover

@m.state()
def read_closed():
pass # pragma: no cover

@m.state()
def write_closed():
pass # pragma: no cover

@m.state()
def open_full(self):
pass # pragma: no cover

@m.state()
@@ -102,6 +126,14 @@ def closing():
def closed():
pass # pragma: no cover

@m.input()
def connect_protocol_half(self):
pass

@m.input()
def connect_protocol_full(self):
pass

@m.input()
def remote_data(self, data):
pass
@@ -118,6 +150,14 @@ def local_data(self, data):
def local_close(self):
pass

@m.output()
def queue_remote_data(self, data):
self._pending_remote_data.append(data)

@m.output()
def queue_remote_close(self):
self._pending_remote_close = True

@m.output()
def send_data(self, data):
self._manager.send_data(self._scid, data)
@@ -128,17 +168,24 @@ def send_close(self):

@m.output()
def signal_dataReceived(self, data):
if self._protocol:
self._protocol.dataReceived(data)
else:
self._pending_dataReceived.append(data)
assert self._protocol
self._protocol.dataReceived(data)

@m.output()
def signal_readConnectionLost(self):
IHalfCloseableProtocol(self._protocol).readConnectionLost()

@m.output()
def signal_writeConnectionLost(self):
IHalfCloseableProtocol(self._protocol).writeConnectionLost()

@m.output()
def signal_connectionLost(self):
if self._protocol:
self._protocol.connectionLost(ConnectionDone())
else:
self._pending_connectionLost = (True, ConnectionDone())
assert self._protocol
self._protocol.connectionLost(ConnectionDone())

@m.output()
def close_subchannel(self):
self._manager.subchannel_closed(self._scid, self)
# we're deleted momentarily

@@ -151,14 +198,44 @@ def error_closed_close(self):
raise AlreadyClosedError(
"loseConnection not allowed on closed subchannel")

# primary transitions
open.upon(remote_data, enter=open, outputs=[signal_dataReceived])
open.upon(local_data, enter=open, outputs=[send_data])
open.upon(remote_close, enter=closed, outputs=[send_close, signal_connectionLost])
open.upon(local_close, enter=closing, outputs=[send_close])
# stuff that arrives before we have a protocol connected
unconnected.upon(remote_data, enter=unconnected, outputs=[queue_remote_data])
unconnected.upon(remote_close, enter=unconnected, outputs=[queue_remote_close])

# IHalfCloseableProtocol flow
unconnected.upon(connect_protocol_half, enter=open_half, outputs=[])
open_half.upon(remote_data, enter=open_half, outputs=[signal_dataReceived])
open_half.upon(local_data, enter=open_half, outputs=[send_data])
# remote closes first
open_half.upon(remote_close, enter=read_closed, outputs=[signal_readConnectionLost])
read_closed.upon(local_data, enter=read_closed, outputs=[send_data])
read_closed.upon(local_close, enter=closed, outputs=[send_close,
close_subchannel,
# TODO: eventual-signal this?
signal_writeConnectionLost,
])
# local closes first
open_half.upon(local_close, enter=write_closed, outputs=[signal_writeConnectionLost,
send_close])
write_closed.upon(local_data, enter=write_closed, outputs=[error_closed_write])
write_closed.upon(remote_data, enter=write_closed, outputs=[signal_dataReceived])
write_closed.upon(remote_close, enter=closed, outputs=[close_subchannel,
signal_readConnectionLost,
])
# error cases
write_closed.upon(local_close, enter=write_closed, outputs=[error_closed_close])

# fully-closeable-only flow
unconnected.upon(connect_protocol_full, enter=open_full, outputs=[])
open_full.upon(remote_data, enter=open_full, outputs=[signal_dataReceived])
open_full.upon(local_data, enter=open_full, outputs=[send_data])
open_full.upon(remote_close, enter=closed, outputs=[send_close,
close_subchannel,
signal_connectionLost])
open_full.upon(local_close, enter=closing, outputs=[send_close])
closing.upon(remote_data, enter=closing, outputs=[signal_dataReceived])
closing.upon(remote_close, enter=closed, outputs=[signal_connectionLost])

closing.upon(remote_close, enter=closed, outputs=[close_subchannel,
signal_connectionLost])
# error cases
# we won't ever see an OPEN, since L4 will log+ignore those for us
closing.upon(local_data, enter=closing, outputs=[error_closed_write])
@@ -170,15 +247,19 @@ def error_closed_close(self):
def _set_protocol(self, protocol):
assert not self._protocol
self._protocol = protocol
if IHalfCloseableProtocol.providedBy(protocol):
self.connect_protocol_half()
else:
# move from UNCONNECTED to OPEN
self.connect_protocol_full();

def _deliver_queued_data(self):
if self._pending_dataReceived:
for data in self._pending_dataReceived:
self._protocol.dataReceived(data)
self._pending_dataReceived = []
cl, what = self._pending_connectionLost
if cl:
self._protocol.connectionLost(what)
for data in self._pending_remote_data:
self.remote_data(data)
del self._pending_remote_data
if self._pending_remote_close:
self.remote_close()
del self._pending_remote_close

# ITransport
def write(self, data):
@@ -189,7 +270,18 @@ def write(self, data):
def writeSequence(self, iovec):
self.write(b"".join(iovec))

def loseWriteConnection(self):
if not IHalfCloseableProtocol.providedBy(self._protocol):
# this is a clear error
raise HalfCloseUsedOnNonHalfCloseable()
self.local_close();

def loseConnection(self):
# TODO: what happens if an IHalfCloseableProtocol calls normal
# loseConnection()? I think we need to close the read side too.
if IHalfCloseableProtocol.providedBy(self._protocol):
# I don't know is correct, so avoid this for now
raise NormalCloseUsedOnHalfCloseable()
self.local_close()

def getHost(self):
@@ -216,7 +216,7 @@ class TestManager(unittest.TestCase):
def test_make_side(self):
side = make_side()
self.assertEqual(type(side), type(u""))
self.assertEqual(len(side), 2 * 6)
self.assertEqual(len(side), 2 * 8)

def test_create(self):
m, h = make_manager()

0 comments on commit ab5fe65

Please sign in to comment.
You can’t perform that action at this time.