Skip to content

Commit

Permalink
Simplify transport_serial (modbus use) (#1808)
Browse files Browse the repository at this point in the history
  • Loading branch information
janiversen committed Oct 11, 2023
1 parent 39177d7 commit 1d1750f
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 48 deletions.
82 changes: 35 additions & 47 deletions pymodbus/transport/transport_serial.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,29 +12,27 @@
class SerialTransport(asyncio.Transport):
"""An asyncio serial transport."""

force_poll: bool = False

def __init__(self, loop, protocol, *args, **kwargs):
"""Initialize."""
super().__init__()
self.async_loop = loop
self._protocol: asyncio.BaseProtocol = protocol
self.sync_serial = serial.serial_for_url(*args, **kwargs)
self._write_buffer = []
self._has_reader = False
self._has_writer = False
self.poll_task = None
self._poll_wait_time = 0.0005
self.sync_serial.timeout = 0
self.sync_serial.write_timeout = 0

def setup(self):
"""Prepare to read/write"""
self.async_loop.call_soon(self._protocol.connection_made, self)
if os.name == "nt":
self._has_reader = self.async_loop.call_later(
self._poll_wait_time, self._poll_read
)
if os.name == "nt" or self.force_poll:
self.poll_task = asyncio.create_task(self._polling_task())
else:
self.async_loop.add_reader(self.sync_serial.fileno(), self._read_ready)
self._has_reader = True
self.async_loop.call_soon(self._protocol.connection_made, self)

def close(self, exc=None):
"""Close the transport gracefully."""
Expand All @@ -43,13 +41,13 @@ def close(self, exc=None):
with contextlib.suppress(Exception):
self.sync_serial.flush()

if self._has_reader:
if os.name == "nt":
self._has_reader.cancel()
else:
self.async_loop.remove_reader(self.sync_serial.fileno())
self._has_reader = False
self.flush()
if self.poll_task:
self.poll_task.cancel()
_ = asyncio.ensure_future(self.poll_task)
self.poll_task = None
else:
self.async_loop.remove_reader(self.sync_serial.fileno())
self.sync_serial.close()
self.sync_serial = None
with contextlib.suppress(Exception):
Expand All @@ -58,21 +56,13 @@ def close(self, exc=None):
def write(self, data):
"""Write some data to the transport."""
self._write_buffer.append(data)
if not self._has_writer:
if os.name == "nt":
self._has_writer = self.async_loop.call_soon(self._poll_write)
else:
self.async_loop.add_writer(self.sync_serial.fileno(), self._write_ready)
self._has_writer = True
if not self.poll_task:
self.async_loop.add_writer(self.sync_serial.fileno(), self._write_ready)

def flush(self):
"""Clear output buffer and stops any more data being written"""
if self._has_writer:
if os.name == "nt":
self._has_writer.cancel()
else:
self.async_loop.remove_writer(self.sync_serial.fileno())
self._has_writer = False
if not self.poll_task:
self.async_loop.remove_writer(self.sync_serial.fileno())
self._write_buffer.clear()

# ------------------------------------------------
Expand Down Expand Up @@ -141,34 +131,32 @@ def _write_ready(self):
"""Asynchronously write buffered data."""
data = b"".join(self._write_buffer)
try:
if nlen := self.sync_serial.write(data) < len(data):
self._write_buffer = data[nlen:]
return True
if (nlen := self.sync_serial.write(data)) < len(data):
self._write_buffer = [data[nlen:]]
if not self.poll_task:
self.async_loop.add_writer(
self.sync_serial.fileno(), self._write_ready
)
return
self.flush()
except (BlockingIOError, InterruptedError):
return True
return
except serial.SerialException as exc:
self.close(exc=exc)
return False

def _poll_read(self):
if self._has_reader:
try:
self._has_reader = self.async_loop.call_later(
self._poll_wait_time, self._poll_read
)
async def _polling_task(self):
"""Poll and try to read/write."""
try:
while True:
await asyncio.sleep(self._poll_wait_time)
while self._write_buffer:
self._write_ready()
if self.sync_serial.in_waiting:
self._read_ready()
except serial.SerialException as exc:
self.close(exc=exc)

def _poll_write(self):
if not self._has_writer:
return
if self._write_ready():
self._has_writer = self.async_loop.call_later(
self._poll_wait_time, self._poll_write
)
except serial.SerialException as exc:
self.close(exc=exc)
except asyncio.CancelledError:
pass


async def create_serial_connection(loop, protocol_factory, *args, **kwargs):
Expand Down
6 changes: 5 additions & 1 deletion test/sub_transport/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,10 @@ async def test_external_methods(self):
comm.close()
comm = SerialTransport(mock.MagicMock(), mock.Mock(), "dummy")
comm.abort()
assert await create_serial_connection(
transport, protocol = await create_serial_connection(
asyncio.get_running_loop(), mock.Mock, url="dummy"
)
await asyncio.sleep(0.1)
assert transport
assert protocol
transport.close()
56 changes: 56 additions & 0 deletions test/sub_transport/test_comm.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
CommType,
ModbusProtocol,
)
from pymodbus.transport.transport_serial import SerialTransport


FACTOR = 1.2 if not pytest.IS_WINDOWS else 4.2
Expand Down Expand Up @@ -125,6 +126,61 @@ async def test_connected(self, client, server, use_comm_type):
assert not server.active_connections
server.transport_close()

def wrapped_write(self, data):
"""Wrap serial write, to split parameters."""
return self.serial_write(data[:2])

@pytest.mark.parametrize(
("use_comm_type", "use_host"),
[
(CommType.SERIAL, "socket://localhost:5020"),
],
)
async def test_split_serial_packet(self, client, server):
"""Test connection and data exchange."""
assert await server.transport_listen()
assert await client.transport_connect()
await asyncio.sleep(0.5)
assert len(server.active_connections) == 1
server_connected = list(server.active_connections.values())[0]
test_data = b"abcd"

self.serial_write = ( # pylint: disable=attribute-defined-outside-init
client.transport.sync_serial.write
)
with mock.patch.object(
client.transport.sync_serial, "write", wraps=self.wrapped_write
):
client.transport_send(test_data)
await asyncio.sleep(0.5)
assert server_connected.recv_buffer == test_data
assert not client.recv_buffer
client.transport_close()
server.transport_close()

@pytest.mark.parametrize(
("use_comm_type", "use_host"),
[
(CommType.SERIAL, "socket://localhost:5020"),
],
)
async def test_serial_poll(self, client, server):
"""Test connection and data exchange."""
assert await server.transport_listen()
SerialTransport.force_poll = True
assert await client.transport_connect()
await asyncio.sleep(0.5)
SerialTransport.force_poll = False
assert len(server.active_connections) == 1
server_connected = list(server.active_connections.values())[0]
test_data = b"abcd" * 1000
client.transport_send(test_data)
await asyncio.sleep(0.5)
assert server_connected.recv_buffer == test_data
assert not client.recv_buffer
client.transport_close()
server.transport_close()

@pytest.mark.parametrize(
("use_comm_type", "use_host"),
[
Expand Down

0 comments on commit 1d1750f

Please sign in to comment.