Skip to content

Commit

Permalink
Refactor iosim code so clock advancing is built-in.
Browse files Browse the repository at this point in the history
  • Loading branch information
pythonspeed committed Oct 10, 2023
1 parent ba09799 commit a2f82ea
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 24 deletions.
30 changes: 24 additions & 6 deletions src/twisted/test/iosim.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
Utilities and helpers for simulating a network
"""


from typing import Optional
import itertools

try:
Expand All @@ -20,6 +20,7 @@
from twisted.internet.endpoints import TCP4ClientEndpoint, TCP4ServerEndpoint
from twisted.internet.error import ConnectionRefusedError
from twisted.internet.protocol import Factory, Protocol
from twisted.internet.task import Clock
from twisted.internet.testing import MemoryReactorClock
from twisted.python.failure import Failure

Expand Down Expand Up @@ -284,12 +285,15 @@ class IOPump:
Perhaps this is a utility worthy of being in protocol.py?
"""

def __init__(self, client, server, clientIO, serverIO, debug):
def __init__(self, client, server, clientIO, serverIO, debug, clock=None):
self.client = client
self.server = server
self.clientIO = clientIO
self.serverIO = serverIO
self.debug = debug
if clock is None:
clock = MemoryReactorClock()
self.clock = clock

def flush(self, debug=False):
"""
Expand All @@ -298,7 +302,7 @@ def flush(self, debug=False):
Returns whether any data was moved.
"""
result = False
for x in range(1000):
for _ in range(1000):
if self.pump(debug):
result = True
else:
Expand All @@ -309,10 +313,11 @@ def flush(self, debug=False):

def pump(self, debug=False):
"""
Move data back and forth.
Move data back and forth, and increase clock slightly.
Returns whether any data was moved.
"""
self.clock.advance(0.000001)
if self.debug or debug:
print("-- GLUG --")
sData = self.serverIO.getOutBuffer()
Expand Down Expand Up @@ -356,6 +361,7 @@ def connect(
clientTransport,
debug=False,
greet=True,
clock=None,
):
"""
Create a new L{IOPump} connecting two protocols.
Expand Down Expand Up @@ -383,14 +389,22 @@ def connect(
post-server-greeting state?
@type greet: L{bool}
@param clock: An optional L{Clock}. Pumping the resulting L{IOPump} will
also increase clock time by a small increment.
@return: An L{IOPump} which connects C{serverProtocol} and
C{clientProtocol} and delivers bytes between them when it is pumped.
@rtype: L{IOPump}
"""
serverProtocol.makeConnection(serverTransport)
clientProtocol.makeConnection(clientTransport)
pump = IOPump(
clientProtocol, serverProtocol, clientTransport, serverTransport, debug
clientProtocol,
serverProtocol,
clientTransport,
serverTransport,
debug,
clock=clock,
)
if greet:
# Kick off server greeting, etc
Expand All @@ -405,6 +419,7 @@ def connectedServerAndClient(
serverTransportFactory=makeFakeServer,
debug=False,
greet=True,
clock: Optional[Clock] = None,
):
"""
Connect a given server and client class to each other.
Expand Down Expand Up @@ -435,6 +450,9 @@ def connectedServerAndClient(
post-server-greeting state?
@type greet: L{bool}
@param clock: An optional L{Clock}. Pumping the resulting L{IOPump} will
also increase clock time by a small increment.
@return: the client protocol, the server protocol, and an L{IOPump} which,
when its C{pump} and C{flush} methods are called, will move data
between the created client and server protocol instances.
Expand All @@ -444,7 +462,7 @@ def connectedServerAndClient(
s = ServerClass()
cio = clientTransportFactory(c)
sio = serverTransportFactory(s)
return c, s, connect(s, sio, c, cio, debug, greet)
return c, s, connect(s, sio, c, cio, debug, greet, clock=clock)


def _factoriesShouldConnect(clientInfo, serverInfo):
Expand Down
15 changes: 14 additions & 1 deletion src/twisted/test/test_iosim.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@

from twisted.internet.interfaces import IPushProducer
from twisted.internet.protocol import Protocol
from twisted.test.iosim import FakeTransport, connect
from twisted.internet.task import Clock
from twisted.test.iosim import FakeTransport, connect, connectedServerAndClient
from twisted.trial.unittest import TestCase


Expand Down Expand Up @@ -299,3 +300,15 @@ def test_clientStreamingProducer(self) -> None:
(stream producer) registered with the client transport.
"""
self._testStreamingProducer(mode="client")

def test_timeAdvances(self) -> None:
"""
L{IOPump.pump} advances time in the given L{Clock}.
"""
time_passed = []
clock = Clock()
_, _, pump = connectedServerAndClient(Protocol, Protocol, clock=clock)
clock.callLater(0, lambda: time_passed.append(True))
self.assertFalse(time_passed)
pump.pump()
self.assertTrue(time_passed)
16 changes: 2 additions & 14 deletions src/twisted/test/test_sslverify.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,14 +300,8 @@ def connectionLost(self, reason):
sProto, cProto, pump = connectedServerAndClient(
lambda: serverFactory.buildProtocol(None),
lambda: clientFactory.buildProtocol(None),
clock=clock
)

# Need time to pass for flushing to work:
def flush(pump_flush=pump.flush):
clock.advance(0)
pump_flush()

pump.flush = flush
pump.flush()

return sProto, cProto, serverWrappedProto, clientWrappedProto, pump
Expand Down Expand Up @@ -2087,14 +2081,8 @@ def connectionLost(self, reason):
cProto, sProto, pump = connectedServerAndClient(
lambda: serverTLSFactory.buildProtocol(None),
lambda: clientTLSFactory.buildProtocol(None),
clock=clock
)

# Need time to pass for flushing to work:
def flush(pump_flush=pump.flush):
clock.advance(0)
pump_flush()

pump.flush = flush
pump.flush()

return cProto, sProto, clientWrappedProto, serverWrappedProto, pump
Expand Down
11 changes: 8 additions & 3 deletions src/twisted/web/test/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -892,8 +892,14 @@ def accumulator():
wrapper = serverWrapper(accumulator, reactor).buildProtocol(None)
serverTransport = FakeTransport(wrapper, True)
wrapper.makeConnection(serverTransport)
pump = IOPump(clientProtocol, wrapper, clientTransport, serverTransport, False)
reactor.advance(0)
pump = IOPump(
clientProtocol,
wrapper,
clientTransport,
serverTransport,
False,
clock=reactor,
)
pump.flush()
self.assertNoResult(deferred)
lines = accumulator.currentProtocol.data.split(b"\r\n")
Expand All @@ -907,7 +913,6 @@ def accumulator():
b"\r\nContent-length: 12\r\n\r\n"
b"hello world!"
)
reactor.advance(0)
pump.flush()
response = self.successResultOf(deferred)
self.assertEquals(
Expand Down

0 comments on commit a2f82ea

Please sign in to comment.