diff --git a/pymodbus/client/sync.py b/pymodbus/client/sync.py index fb45ec6f7..9aa5306d4 100644 --- a/pymodbus/client/sync.py +++ b/pymodbus/client/sync.py @@ -195,7 +195,6 @@ def connect(self): try: self.socket = socket.create_connection( (self.host, self.port), - timeout=self.timeout, source_address=self.source_address) except socket.error as msg: _logger.error('Connection to (%s, %s) ' @@ -230,7 +229,27 @@ def _recv(self, size): """ if not self.socket: raise ConnectionException(self.__str__()) - return self.socket.recv(size) + + # socket.recv(size) waits until it gets some data from the host but + # not necessarily the entire response that can be fragmented in + # many packets. + # To avoid the splitted responses to be recognized as invalid messages + # and to be discarded, loops socket.recv until full data is received, + # or timeout is expired. + # If timeout expires returns the read data, also if its length is less + # than the expected size. + self.socket.setblocking(0) + begin = time.time() + + data = b'' + while(len(data) < size): + try: + data += self.socket.recv(size - len(data)) + except socket.error: + pass + if not self.timeout or (time.time() - begin > self.timeout): + break + return data def is_socket_open(self): return True if self.socket is not None else False @@ -320,7 +339,27 @@ def _recv(self, size): """ if not self.socket: raise ConnectionException(self.__str__()) - return self.socket.recvfrom(size)[0] + + # socket.recv(size) waits until it gets some data from the host but + # not necessarily the entire response that can be fragmented in + # many packets. + # To avoid the splitted responses to be recognized as invalid messages + # and to be discarded, loops socket.recv until full data is received, + # or timeout is expired. + # If timeout expires returns the read data, also if its length is less + # than the expected size. + self.socket.setblocking(0) + begin = time.time() + + data = b'' + while(len(data) < size): + try: + data += self.socket.recv(size - len(data)) + except socket.error: + pass + if not self.timeout or (time.time() - begin > self.timeout): + break + return data def is_socket_open(self): return True if self.socket is not None else False diff --git a/pymodbus/exceptions.py b/pymodbus/exceptions.py index a2ad48241..b225a4dd6 100644 --- a/pymodbus/exceptions.py +++ b/pymodbus/exceptions.py @@ -78,7 +78,7 @@ def __init__(self, string=""): ModbusException.__init__(self, message) -class InvalidMessageRecievedException(ModbusException): +class InvalidMessageReceivedException(ModbusException): """ Error resulting from invalid response received or decoded """ diff --git a/pymodbus/framer/rtu_framer.py b/pymodbus/framer/rtu_framer.py index b39649ea4..285fe6da9 100644 --- a/pymodbus/framer/rtu_framer.py +++ b/pymodbus/framer/rtu_framer.py @@ -2,7 +2,7 @@ import time from pymodbus.exceptions import ModbusIOException -from pymodbus.exceptions import InvalidMessageRecievedException +from pymodbus.exceptions import InvalidMessageReceivedException from pymodbus.utilities import checkCRC, computeCRC from pymodbus.utilities import hexlify_packets, ModbusTransactionState from pymodbus.compat import byte2int @@ -313,7 +313,7 @@ def _process(self, callback, error=False): if result is None: raise ModbusIOException("Unable to decode request") elif error and result.function_code < 0x80: - raise InvalidMessageRecievedException(result) + raise InvalidMessageReceivedException(result) else: self.populateResult(result) self.advanceFrame() diff --git a/pymodbus/framer/socket_framer.py b/pymodbus/framer/socket_framer.py index 37e3bfe9d..201018960 100644 --- a/pymodbus/framer/socket_framer.py +++ b/pymodbus/framer/socket_framer.py @@ -1,6 +1,6 @@ import struct from pymodbus.exceptions import ModbusIOException -from pymodbus.exceptions import InvalidMessageRecievedException +from pymodbus.exceptions import InvalidMessageReceivedException from pymodbus.utilities import hexlify_packets from pymodbus.framer import ModbusFramer, SOCKET_FRAME_HEADER @@ -174,7 +174,7 @@ def _process(self, callback, error=False): if result is None: raise ModbusIOException("Unable to decode request") elif error and result.function_code < 0x80: - raise InvalidMessageRecievedException(result) + raise InvalidMessageReceivedException(result) else: self.populateResult(result) self.advanceFrame() diff --git a/pymodbus/transaction.py b/pymodbus/transaction.py index 44beb8a79..dc9f841b1 100644 --- a/pymodbus/transaction.py +++ b/pymodbus/transaction.py @@ -6,7 +6,7 @@ from threading import RLock from pymodbus.exceptions import ModbusIOException, NotImplementedException -from pymodbus.exceptions import InvalidMessageRecievedException +from pymodbus.exceptions import InvalidMessageReceivedException from pymodbus.constants import Defaults from pymodbus.framer.ascii_framer import ModbusAsciiFramer from pymodbus.framer.rtu_framer import ModbusRtuFramer @@ -74,14 +74,9 @@ def _set_adu_size(self): self.base_adu_size = 7 # start(1)+ Address(2), LRC(2) + end(2) elif isinstance(self.client.framer, ModbusBinaryFramer): self.base_adu_size = 5 # start(1) + Address(1), CRC(2) + end(1) - else: - self.base_adu_size = -1 def _calculate_response_length(self, expected_pdu_size): - if self.base_adu_size == -1: - return None - else: - return self.base_adu_size + expected_pdu_size + return self.base_adu_size + expected_pdu_size def _calculate_exception_length(self): ''' Returns the length of the Modbus Exception Response according to @@ -94,8 +89,6 @@ def _calculate_exception_length(self): elif isinstance(self.client.framer, (ModbusRtuFramer, ModbusBinaryFramer)): return self.base_adu_size + 2 # Fcode(1), ExcecptionCode(1) - return None - def _check_response(self, response): ''' Checks if the response is a Modbus Exception. ''' @@ -208,11 +201,11 @@ def _transact(self, packet, response_length, full=False): _logger.debug("Changing transaction state from 'SENDING' " "to 'WAITING FOR REPLY'") self.client.state = ModbusTransactionState.WAITING_FOR_REPLY - result = self._recv(response_length or 1024, full) + result = self._recv(response_length, full) if _logger.isEnabledFor(logging.DEBUG): _logger.debug("RECV: " + hexlify_packets(result)) except (socket.error, ModbusIOException, - InvalidMessageRecievedException) as msg: + InvalidMessageReceivedException) as msg: self.client.close() _logger.debug("Transaction failed. (%s) " % msg) last_exception = msg @@ -223,7 +216,6 @@ def _send(self, packet): return self.client.framer.sendPacket(packet) def _recv(self, expected_response_length, full): - expected_response_length = expected_response_length or 1024 if not full: exception_length = self._calculate_exception_length() if isinstance(self.client.framer, ModbusSocketFramer): @@ -238,31 +230,37 @@ def _recv(self, expected_response_length, full): min_size = expected_response_length read_min = self.client.framer.recvPacket(min_size) - if read_min: + if not read_min: + return read_min + + if len(read_min) < min_size: + raise InvalidMessageReceivedException( + "Incomplete message received, expected at least %d bytes (%d received)" + % (min_size, len(read_min))) + + if isinstance(self.client.framer, ModbusSocketFramer): + func_code = byte2int(read_min[-1]) + elif isinstance(self.client.framer, ModbusRtuFramer): + func_code = byte2int(read_min[-1]) + elif isinstance(self.client.framer, ModbusAsciiFramer): + func_code = int(read_min[3:5], 16) + elif isinstance(self.client.framer, ModbusBinaryFramer): + func_code = byte2int(read_min[-1]) + else: + func_code = -1 + + if func_code < 0x80: # Not an error if isinstance(self.client.framer, ModbusSocketFramer): - func_code = byte2int(read_min[-1]) - elif isinstance(self.client.framer, ModbusRtuFramer): - func_code = byte2int(read_min[-1]) - elif isinstance(self.client.framer, ModbusAsciiFramer): - func_code = int(read_min[3:5], 16) - elif isinstance(self.client.framer, ModbusBinaryFramer): - func_code = byte2int(read_min[-1]) - else: - func_code = -1 - - if func_code < 0x80: # Not an error - if isinstance(self.client.framer, ModbusSocketFramer): - # Ommit UID, which is included in header size - h_size = self.client.framer._hsize - length = struct.unpack(">H", read_min[4:6])[0] - 1 - expected_response_length = h_size + length - expected_response_length -= min_size - total = expected_response_length + min_size - else: - expected_response_length = exception_length - min_size - total = expected_response_length + min_size + # Ommit UID, which is included in header size + h_size = self.client.framer._hsize + length = struct.unpack(">H", read_min[4:6])[0] - 1 + expected_response_length = h_size + length + expected_response_length -= min_size + total = expected_response_length + min_size else: - total = expected_response_length + expected_response_length = exception_length - min_size + total = expected_response_length + min_size + else: read_min = b'' total = expected_response_length @@ -273,6 +271,9 @@ def _recv(self, expected_response_length, full): _logger.debug("Incomplete message received, " "Expected {} bytes Recieved " "{} bytes !!!!".format(total, actual)) + raise InvalidMessageReceivedException( + "Incomplete message received, %d bytes expected (%d received)" + % (total, actual)) if self.client.state != ModbusTransactionState.PROCESSING_REPLY: _logger.debug("Changing transaction state from " "'WAITING FOR REPLY' to 'PROCESSING REPLY'") diff --git a/test/test_client_sync.py b/test/test_client_sync.py index d503d0dda..e13b7a762 100644 --- a/test/test_client_sync.py +++ b/test/test_client_sync.py @@ -2,9 +2,9 @@ import unittest from pymodbus.compat import IS_PYTHON3 if IS_PYTHON3: # Python 3 - from unittest.mock import patch, Mock + from unittest.mock import patch, Mock, MagicMock else: # Python 2 - from mock import patch, Mock + from mock import patch, Mock, MagicMock import socket import serial @@ -20,12 +20,13 @@ #---------------------------------------------------------------------------# class mockSocket(object): def close(self): return True - def recv(self, size): return '\x00'*size - def read(self, size): return '\x00'*size + def recv(self, size): return b'\x00'*size + def read(self, size): return b'\x00'*size def send(self, msg): return len(msg) def write(self, msg): return len(msg) - def recvfrom(self, size): return ['\x00'*size] + def recvfrom(self, size): return [b'\x00'*size] def sendto(self, msg, *args): return len(msg) + def setblocking(self, flag): return None def in_waiting(self): return None #---------------------------------------------------------------------------# @@ -80,8 +81,8 @@ def testBasicSyncUdpClient(self): client = ModbusUdpClient() client.socket = mockSocket() self.assertEqual(0, client._send(None)) - self.assertEqual(1, client._send('\x00')) - self.assertEqual('\x00', client._recv(1)) + self.assertEqual(1, client._send(b'\x00')) + self.assertEqual(b'\x00', client._recv(1)) # connect/disconnect self.assertTrue(client.connect()) @@ -129,8 +130,19 @@ def testUdpClientRecv(self): self.assertRaises(ConnectionException, lambda: client._recv(1024)) client.socket = mockSocket() - self.assertEqual('', client._recv(0)) - self.assertEqual('\x00'*4, client._recv(4)) + self.assertEqual(b'', client._recv(0)) + self.assertEqual(b'\x00'*4, client._recv(4)) + + mock_socket = MagicMock() + mock_socket.recv.side_effect = iter([b'\x00', b'\x01', b'\x02']) + client.socket = mock_socket + client.timeout = 1 + self.assertEqual(b'\x00\x01\x02', client._recv(3)) + mock_socket.recv.side_effect = iter([b'\x00', b'\x01', b'\x02']) + self.assertEqual(b'\x00\x01', client._recv(2)) + + mock_socket.recv.side_effect = socket.error('No data') + self.assertEqual(b'', client._recv(2)) #-----------------------------------------------------------------------# # Test TCP Client @@ -147,8 +159,8 @@ def testBasicSyncTcpClient(self): client = ModbusTcpClient() client.socket = mockSocket() self.assertEqual(0, client._send(None)) - self.assertEqual(1, client._send('\x00')) - self.assertEqual('\x00', client._recv(1)) + self.assertEqual(1, client._send(b'\x00')) + self.assertEqual(b'\x00', client._recv(1)) # connect/disconnect self.assertTrue(client.connect()) @@ -187,9 +199,20 @@ def testTcpClientRecv(self): self.assertRaises(ConnectionException, lambda: client._recv(1024)) client.socket = mockSocket() - self.assertEqual('', client._recv(0)) - self.assertEqual('\x00'*4, client._recv(4)) - + self.assertEqual(b'', client._recv(0)) + self.assertEqual(b'\x00'*4, client._recv(4)) + + mock_socket = MagicMock() + mock_socket.recv.side_effect = iter([b'\x00', b'\x01', b'\x02']) + client.socket = mock_socket + client.timeout = 1 + self.assertEqual(b'\x00\x01\x02', client._recv(3)) + mock_socket.recv.side_effect = iter([b'\x00', b'\x01', b'\x02']) + self.assertEqual(b'\x00\x01', client._recv(2)) + + mock_socket.recv.side_effect = socket.error('No data') + self.assertEqual(b'', client._recv(2)) + #-----------------------------------------------------------------------# # Test Serial Client #-----------------------------------------------------------------------# @@ -217,14 +240,14 @@ def testBasicSyncSerialClient(self, mock_serial): mock_serial.in_waiting = 0 mock_serial.write = lambda x: len(x) - mock_serial.read = lambda size: '\x00' * size + mock_serial.read = lambda size: b'\x00' * size client = ModbusSerialClient() client.socket = mock_serial client.state = 0 self.assertEqual(0, client._send(None)) client.state = 0 - self.assertEqual(1, client._send('\x00')) - self.assertEqual('\x00', client._recv(1)) + self.assertEqual(1, client._send(b'\x00')) + self.assertEqual(b'\x00', client._recv(1)) # connect/disconnect self.assertTrue(client.connect()) @@ -283,8 +306,8 @@ def testSerialClientRecv(self): self.assertRaises(ConnectionException, lambda: client._recv(1024)) client.socket = mockSocket() - self.assertEqual('', client._recv(0)) - self.assertEqual('\x00'*4, client._recv(4)) + self.assertEqual(b'', client._recv(0)) + self.assertEqual(b'\x00'*4, client._recv(4)) #---------------------------------------------------------------------------# # Main diff --git a/test/test_transaction.py b/test/test_transaction.py index 164dc9626..fd744c2eb 100644 --- a/test/test_transaction.py +++ b/test/test_transaction.py @@ -8,10 +8,13 @@ ModbusRtuFramer, ModbusBinaryFramer ) from pymodbus.factory import ServerDecoder -from pymodbus.compat import byte2int -from mock import MagicMock +from pymodbus.compat import IS_PYTHON3, byte2int +if IS_PYTHON3: # Python 3 + from unittest.mock import MagicMock, PropertyMock +else: # Python 2 + from mock import MagicMock, PropertyMock from pymodbus.exceptions import ( - NotImplementedException, ModbusIOException, InvalidMessageRecievedException + NotImplementedException, ModbusIOException, InvalidMessageReceivedException ) class ModbusTransactionTest(unittest.TestCase): @@ -41,6 +44,25 @@ def tearDown(self): del self._rtu del self._ascii + def testTransactionManagerRecv(self): + mock_client = MagicMock() + mock_client.recv.side_effect = iter([b'\x11\x04', b'\x02\x00\x0a\xf8\xf4']) + framer = ModbusRtuFramer(self.decoder, mock_client) + type(mock_client).framer = PropertyMock(return_value=framer) + manager = ModbusTransactionManager(mock_client) + self.assertEqual(manager._recv(7, False), b'\x11\x04\x02\x00\x0a\xf8\xf4') + + mock_client.recv.side_effect = iter([b'\x11']) + with self.assertRaises(InvalidMessageReceivedException): + manager._recv(7, False) + + mock_client.recv.side_effect = iter([b'']) + self.assertEqual(manager._recv(7, False), b'') + + mock_client.recv.side_effect = iter([b'\x11\x04', b'\x02']) + with self.assertRaises(InvalidMessageReceivedException): + manager._recv(7, False) + #---------------------------------------------------------------------------# # Dictionary based transaction manager #---------------------------------------------------------------------------#