diff --git a/pymodbus/client/async/asyncio/__init__.py b/pymodbus/client/async/asyncio/__init__.py index 9a9ffff0d..3921d4409 100644 --- a/pymodbus/client/async/asyncio/__init__.py +++ b/pymodbus/client/async/asyncio/__init__.py @@ -659,7 +659,8 @@ class AsyncioModbusSerialClient(object): transport = None framer = None - def __init__(self, port, protocol_class=None, framer=None, loop=None): + def __init__(self, port, protocol_class=None, framer=None, loop=None, + baudrate=9600, bytesize=8, parity='N', stopbits=1): """ Initializes Asyncio Modbus Serial Client :param port: Port to connect @@ -674,30 +675,31 @@ def __init__(self, port, protocol_class=None, framer=None, loop=None): #: Event loop to use. self.loop = loop or asyncio.get_event_loop() self.port = port + self.baudrate = baudrate + self.bytesize = bytesize + self.parity = parity + self.stopbits = stopbits self.framer = framer - self._connected = False + self._connected_event = asyncio.Event() def stop(self): """ Stops connection :return: """ - if self._connected: + if self._connected.is_set(): if self.protocol: if self.protocol.transport: self.protocol.transport.close() def _create_protocol(self): - """ - Factory function to create initialized protocol instance. - """ - from serial_asyncio import create_serial_connection - - def factory(): - return self.protocol_class(framer=self.framer) + protocol = self.protocol_class(framer=self.framer) + protocol.factory = self + return protocol - cor = create_serial_connection(self.loop, factory, self.port) - return cor + @property + def _connected(self): + return self._connected_event.is_set() @asyncio.coroutine def connect(self): @@ -707,11 +709,16 @@ def connect(self): """ _logger.debug('Connecting.') try: - yield from self.loop.create_connection(self._create_protocol) - _logger.info('Connected to %s:%s.' % (self.host, self.port)) + from serial_asyncio import create_serial_connection + + yield from create_serial_connection( + self.loop, self._create_protocol, self.port, baudrate=self.baudrate, + bytesize=self.bytesize, stopbits=self.stopbits + ) + yield from self._connected_event.wait() + _logger.info('Connected to %s', self.port) except Exception as ex: - _logger.warning('Failed to connect: %s' % ex) - # asyncio.async(self._reconnect(), loop=self.loop) + _logger.warning('Failed to connect: %s', ex) def protocol_made_connection(self, protocol): """ @@ -719,7 +726,7 @@ def protocol_made_connection(self, protocol): """ _logger.info('Protocol made connection.') if not self._connected: - self._connected = True + self._connected_event.set() self.protocol = protocol else: _logger.error('Factory protocol connect ' @@ -735,7 +742,7 @@ def protocol_lost_connection(self, protocol): _logger.error('Factory protocol callback called' ' from unexpected protocol instance.') - self._connected = False + self._connected_event.clear() self.protocol = None # if self.host: # asyncio.async(self._reconnect(), loop=self.loop) diff --git a/pymodbus/client/async/factory/serial.py b/pymodbus/client/async/factory/serial.py index 60d147263..8a0a67b75 100644 --- a/pymodbus/client/async/factory/serial.py +++ b/pymodbus/client/async/factory/serial.py @@ -84,7 +84,7 @@ def async_io_factory(port=None, framer=None, **kwargs): Factory to create asyncio based async serial clients :param port: Serial port :param framer: Modbus Framer - :param kwargs: + :param kwargs: Serial port options :return: asyncio event loop and serial client """ import asyncio @@ -101,11 +101,9 @@ def async_io_factory(port=None, framer=None, **kwargs): import sys sys.exit(1) - client = AsyncioModbusSerialClient(port, proto_cls, framer, loop) - coro = client._create_protocol() - transport, protocol = loop.run_until_complete(asyncio.gather(coro))[0] - client.transport = transport - client.protocol = protocol + client = AsyncioModbusSerialClient(port, proto_cls, framer, loop, **kwargs) + coro = client.connect() + loop.run_until_complete(coro) return loop, client diff --git a/test/test_client_async.py b/test/test_client_async.py index 2c72c6b40..d16052767 100644 --- a/test/test_client_async.py +++ b/test/test_client_async.py @@ -39,25 +39,10 @@ # ---------------------------------------------------------------------------# -def mock_create_serial_connection(a, b, port): - ser = MagicMock() - ser.port = port - protocol = b() - transport = SerialTransport(a, protocol, ser) - protocol.transport = transport - return transport, protocol - - def mock_asyncio_gather(coro): return coro -def mock_asyncio_run_untill_complete(val): - transport, protocol = val - protocol._connected = True - return ([transport, protocol], ) - - class TestAsynchronousClient(object): """ This is the unittest for the pymodbus.client.async module @@ -237,36 +222,27 @@ def testSerialAsyncioClientPython2(self): assert pytest_wrapped_e.value.code == 1 @pytest.mark.skipif(not IS_PYTHON3 or PYTHON_VERSION < (3, 4), reason="requires python3.4 or above") - @patch("serial_asyncio.create_serial_connection", side_effect=mock_create_serial_connection) @patch("asyncio.get_event_loop") @patch("asyncio.gather", side_effect=mock_asyncio_gather) @pytest.mark.parametrize("method, framer", [("rtu", ModbusRtuFramer), ("socket", ModbusSocketFramer), ("binary", ModbusBinaryFramer), ("ascii", ModbusAsciiFramer)]) - def testSerialAsyncioClient(self, mock_gather, mock_event_loop, mock_serial_connection, method, framer): + def testSerialAsyncioClient(self, mock_gather, mock_event_loop, method, framer): """ - Test Serial async asyncio client exits on python2 + Test that AsyncModbusSerialClient instantiates AsyncioModbusSerialClient for asyncio scheduler. :return: """ loop = asyncio.get_event_loop() - loop.run_until_complete.side_effect = mock_asyncio_run_untill_complete - loop, client = AsyncModbusSerialClient(schedulers.ASYNC_IO, method=method, port=SERIAL_PORT, loop=loop) + loop, client = AsyncModbusSerialClient(schedulers.ASYNC_IO, method=method, port=SERIAL_PORT, loop=loop, + baudrate=19200, parity='E', stopbits=2, bytesize=7) assert(isinstance(client, AsyncioModbusSerialClient)) - assert(len(list(client.protocol.transaction)) == 0) assert(isinstance(client.framer, framer)) - assert(client.protocol._connected) - - d = client.protocol._buildResponse(0x00) - - def handle_failure(failure): - assert(isinstance(failure.exception(), ConnectionException)) - - d.add_done_callback(handle_failure) - assert(client.protocol._connected) - client.protocol.close() - assert(not client.protocol._connected) - pass + assert(client.port == SERIAL_PORT) + assert(client.baudrate == 19200) + assert(client.parity == 'E') + assert(client.stopbits == 2) + assert(client.bytesize == 7) # ---------------------------------------------------------------------------#