Skip to content
Merged
56 changes: 44 additions & 12 deletions serialx/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,17 @@ def __init__(

self._wrap_exceptions = _wrap_exceptions

# Enter a "broken" state so that an error condition can persist
self._broken: Exception | None = None

def _mark_broken(self, exc: Exception) -> None:
if self._broken is None:
self._broken = exc

def _check_broken(self) -> None:
if self._broken is not None:
raise self._broken

@classmethod
def from_url(cls, url: str, *args: Any, **kwargs: Any) -> BaseSerial:
"""Create the appropriate serial port subclass for the given URL."""
Expand All @@ -393,9 +404,9 @@ def from_url(cls, url: str, *args: Any, **kwargs: Any) -> BaseSerial:
@maybe_wrap_exceptions
def open(self) -> None:
"""Open the serial port."""
self._open()

self._broken = None
try:
self._open()
self._configure_port()
except BaseException:
self.close()
Expand Down Expand Up @@ -436,8 +447,10 @@ def write_timeout(self) -> float | None:
"""Get the write timeout in seconds."""
return self._write_timeout

@maybe_wrap_exceptions
def get_modem_pins(self) -> ModemPins:
"""Get modem control bits."""
self._check_broken()
return self._get_modem_pins()

@maybe_wrap_exceptions
Expand All @@ -456,6 +469,8 @@ def set_modem_pins(
dsr: PinState | bool | None = PinState.UNDEFINED,
) -> None:
"""Set modem control bits."""
self._check_broken()

if modem_pins is None:
modem_pins = ModemPins(
le=PinState.convert(le),
Expand Down Expand Up @@ -484,6 +499,7 @@ def _set_modem_pins(self, modem_pins: ModemPins) -> None:
@maybe_wrap_exceptions
def readinto(self, b: Buffer, *, timeout: float | None = None) -> int:
"""Read bytes from serial port into buffer."""
self._check_broken()
timeout = self._read_timeout if timeout is None else timeout
return self._readinto(b, timeout=timeout)

Expand All @@ -495,6 +511,7 @@ def _readinto(self, b: Buffer, *, timeout: float | None) -> int:
@maybe_wrap_exceptions
def write(self, data: Buffer, *, timeout: float | None = None) -> int:
"""Write bytes to serial port."""
self._check_broken()
timeout = self._write_timeout if timeout is None else timeout
return self._write(data, timeout=timeout)

Expand Down Expand Up @@ -879,6 +896,14 @@ def __init__(
self._closing: bool = False
self._closed_waiter: asyncio.Future[None] = loop.create_future()

def _mark_broken(self, exc: Exception) -> None:
if self._serial is not None:
self._serial._mark_broken(exc)

def _check_broken(self) -> None:
if self._serial is not None:
self._serial._check_broken()

def is_closing(self) -> bool:
"""Return whether the transport is closing."""
return self._closing
Expand Down Expand Up @@ -981,19 +1006,18 @@ async def connect(

async def get_modem_pins(self) -> ModemPins:
"""Get modem control bits."""
assert self._serial is not None
return await self._loop.run_in_executor(None, self._serial.get_modem_pins)
self._check_broken()
return await self._get_modem_pins()

@abstractmethod
async def _get_modem_pins(self) -> ModemPins:
"""Get modem control bits, internal."""
raise NotImplementedError

@abstractmethod
async def _set_modem_pins(self, modem_pins: ModemPins) -> None:
"""Set modem control bits, internal."""
await self._loop.run_in_executor(
None,
lambda: (
self._serial._set_modem_pins(modem_pins)
if self._serial is not None
else None
),
)
raise NotImplementedError

async def set_modem_pins(
self,
Expand All @@ -1010,6 +1034,8 @@ async def set_modem_pins(
dsr: PinState | bool | None = PinState.UNDEFINED,
) -> None:
"""Set modem control bits."""
self._check_broken()

if modem_pins is None:
modem_pins = ModemPins(
le=PinState.convert(le),
Expand All @@ -1027,6 +1053,12 @@ async def set_modem_pins(

async def flush(self) -> None:
"""Flush write buffers, waiting until all data is written."""
self._check_broken()
await self._flush()

@abstractmethod
async def _flush(self) -> None:
"""Flush write buffers, waiting until all data is written, internal."""
raise NotImplementedError

async def wait_closed(self) -> None:
Expand Down
13 changes: 8 additions & 5 deletions serialx/descriptor_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,11 +136,14 @@ def _read_ready(self) -> None:
if data:
self._protocol.data_received(data)
else:
LOGGER.info("%r was closed by peer", self)
self._closing = True
self._loop.remove_reader(self._fileno)
self._loop.call_soon(self._protocol.eof_received)
self._maybe_background_close(None)
# Linux's hung_up_tty_read returns 0 (drivers/tty/tty_io.c); surface
# this as -EIO to match what write/ioctl already get from the kernel,
# so consumers don't busy-loop on b''.
disconnect = OSError(
errno.EIO, "device disconnected or in use by another process"
)
self._mark_broken(disconnect)
self._close(disconnect)

def pause_reading(self) -> None:
"""Pause reading from the file descriptor."""
Expand Down
13 changes: 9 additions & 4 deletions serialx/platforms/serial_esphome.py
Original file line number Diff line number Diff line change
Expand Up @@ -691,16 +691,21 @@ async def _async_close(self, api: APIClient) -> None:
finally:
self._call_protocol_connection_lost(None)

async def flush(self) -> None:
"""Flush write buffers."""
async def _flush(self) -> None:
"""Flush write buffers, waiting until all data is written, internal."""
assert self._serial is not None
await self._serial._async_flush()

async def get_modem_pins(self) -> ModemPins:
"""Get modem control bits."""
async def _get_modem_pins(self) -> ModemPins:
"""Get modem control bits, internal."""
assert self._serial is not None
return await self._serial._async_get_modem_pins()

async def _set_modem_pins(self, modem_pins: ModemPins) -> None:
"""Set modem control bits, internal."""
assert self._serial is not None
await self._serial._async_set_modem_pins(modem_pins)

def get_write_buffer_size(self) -> int:
"""Get the number of bytes currently in the write buffer."""
return 0
Expand Down
30 changes: 28 additions & 2 deletions serialx/platforms/serial_posix.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,14 @@ def _readinto(self, b: Buffer, *, timeout: float | None) -> int:
n = os.readinto(self._fileno, b)
LOGGER.debug("Read %d bytes", n)

if n == 0:
self._mark_broken(
OSError(
errno.EIO, "device disconnected or in use by another process"
)
)
self._check_broken()

return n

else:
Expand All @@ -428,6 +436,14 @@ def _readinto(self, b: Buffer, *, timeout: float | None) -> int:
m[:n] = chunk
LOGGER.debug("Read %d bytes: %r", n, chunk)

if n == 0:
self._mark_broken(
OSError(
errno.EIO, "device disconnected or in use by another process"
)
)
self._check_broken()

return n

def _write(self, data: Buffer, *, timeout: float | None) -> int:
Expand Down Expand Up @@ -509,8 +525,8 @@ async def _connect( # type: ignore[override]

self._protocol.connection_made(self)

async def flush(self) -> None:
"""Flush write buffers, waiting until all data is written."""
async def _flush(self) -> None:
"""Flush write buffers, waiting until all data is written, internal."""
assert self._serial is not None

try:
Expand All @@ -523,6 +539,16 @@ async def flush(self) -> None:
finally:
self._reset_empty_waiter()

async def _get_modem_pins(self) -> ModemPins:
"""Get modem control bits, internal."""
assert self._serial is not None
return await self._loop.run_in_executor(None, self._serial.get_modem_pins)

async def _set_modem_pins(self, modem_pins: ModemPins) -> None:
"""Set modem control bits, internal."""
assert self._serial is not None
await self._loop.run_in_executor(None, self._serial._set_modem_pins, modem_pins)


register_uri_handler(
scheme="device://",
Expand Down
8 changes: 4 additions & 4 deletions serialx/platforms/serial_pyodide/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,8 +282,8 @@ async def _reader_loop(self) -> None:
assert self._protocol is not None
self._protocol.data_received(bytes(result.value))

async def get_modem_pins(self) -> ModemPins:
"""Get modem control bits."""
async def _get_modem_pins(self) -> ModemPins:
"""Get modem control bits, internal."""
assert self._js_port is not None
result = await self._js_port.getSignals()

Expand Down Expand Up @@ -315,8 +315,8 @@ def get_write_buffer_size(self) -> int:
"""Return the number of bytes currently queued for writing."""
return self._write_buffer_size

async def flush(self) -> None:
"""Flush write buffers, waiting until all data is written."""
async def _flush(self) -> None:
"""Flush write buffers, waiting until all data is written, internal."""
await self._write_queue.join()

def abort(self) -> None:
Expand Down
30 changes: 20 additions & 10 deletions serialx/platforms/serial_rfc2217/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from collections.abc import Generator
from contextlib import suppress
from enum import IntEnum
import errno
import logging
import sys

Expand Down Expand Up @@ -537,7 +538,10 @@ def _recv_and_process(self) -> None:
n = self._socket.recv_into(buf)

if n == 0:
raise SerialException("RFC 2217 connection closed by server")
self._mark_broken(
OSError(errno.EIO, "RFC 2217 connection closed by server")
)
self._check_broken()
Comment thread
puddly marked this conversation as resolved.

raw = bytes(buf[:n])
LOGGER.debug("RX raw: %d bytes [%s]", n, raw.hex(" "))
Expand Down Expand Up @@ -633,7 +637,10 @@ def _readinto(self, b: Buffer, *, timeout: float | None) -> int:
timeout -= get_elapsed()

if n == 0:
return 0
self._mark_broken(
OSError(errno.EIO, "RFC 2217 connection closed by server")
)
self._check_broken()

raw = bytes(buf[:n])
LOGGER.debug("RX raw (readinto): %d bytes [%s]", n, raw.hex(" "))
Expand Down Expand Up @@ -902,12 +909,13 @@ async def _send_and_wait(self, cmd: Rfc2217Command) -> Rfc2217Command:

def write(self, data: bytes | bytearray | memoryview) -> None:
"""Write data to the serial port, escaping IAC bytes."""
assert self._tcp_transport is not None
if self._tcp_transport is None:
return
escaped = iac_escape(bytes(data))
LOGGER.debug("TX data: %d bytes (%d on wire)", len(data), len(escaped))
self._tcp_transport.write(escaped)

async def get_modem_pins(self) -> ModemPins:
async def _get_modem_pins(self) -> ModemPins:
"""Return modem pin state from the last NOTIFY-MODEMSTATE."""
assert self._serial is not None
return self._serial._engine.get_modem_pins()
Expand Down Expand Up @@ -964,17 +972,19 @@ def _tcp_connection_lost(self, exc: Exception | None) -> None:
self._closing = True
self._tcp_transport = None

# Fail any pending waiters
waiter_exc = exc or SerialException("RFC 2217 connection closed by server")
if exc is None:
exc = OSError(errno.EIO, "RFC 2217 connection closed by server")
self._mark_broken(exc)

Comment thread
puddly marked this conversation as resolved.
# Fail any pending waiters
for _expected, telnet_waiter in self._telnet_waiters:
if not telnet_waiter.done():
telnet_waiter.set_exception(waiter_exc)
telnet_waiter.set_exception(exc)
self._telnet_waiters.clear()

for rfc2217_waiter in self._rfc2217_waiters.values():
if not rfc2217_waiter.done():
rfc2217_waiter.set_exception(waiter_exc)
rfc2217_waiter.set_exception(exc)
self._rfc2217_waiters.clear()

if self._serial is not None:
Expand Down Expand Up @@ -1014,8 +1024,8 @@ def abort(self) -> None:
else:
self._tcp_connection_lost(None)

async def flush(self) -> None:
"""Wait for the server to acknowledge all preceding writes."""
async def _flush(self) -> None:
"""Flush write buffers, waiting until all data is written, internal."""
assert self._serial is not None
if not self._serial._engine.negotiated:
return
Expand Down
Loading
Loading