diff --git a/examples/common/dbstore-update-server.py b/examples/common/dbstore-update-server.py new file mode 100644 index 000000000..b0ce87bf0 --- /dev/null +++ b/examples/common/dbstore-update-server.py @@ -0,0 +1,94 @@ +''' +Pymodbus Server With Updating Thread +-------------------------------------------------------------------------- +This is an example of having a background thread updating the +context in an SQLite4 database while the server is operating. + +This scrit generates a random address range (within 0 - 65000) and a random +value and stores it in a database. It then reads the same address to verify +that the process works as expected + +This can also be done with a python thread:: + from threading import Thread + thread = Thread(target=updating_writer, args=(context,)) + thread.start() +''' +#---------------------------------------------------------------------------# +# import the modbus libraries we need +#---------------------------------------------------------------------------# +from pymodbus.server.async import StartTcpServer +from pymodbus.device import ModbusDeviceIdentification +from pymodbus.datastore import ModbusSequentialDataBlock +from pymodbus.datastore import ModbusServerContext +from pymodbus.datastore.database import SqlSlaveContext +from pymodbus.transaction import ModbusRtuFramer, ModbusAsciiFramer +import random + +#---------------------------------------------------------------------------# +# import the twisted libraries we need +#---------------------------------------------------------------------------# +from twisted.internet.task import LoopingCall + +#---------------------------------------------------------------------------# +# configure the service logging +#---------------------------------------------------------------------------# +import logging +logging.basicConfig() +log = logging.getLogger() +log.setLevel(logging.DEBUG) + +#---------------------------------------------------------------------------# +# define your callback process +#---------------------------------------------------------------------------# +def updating_writer(a): + ''' A worker process that runs every so often and + updates live values of the context which resides in an SQLite3 database. + It should be noted that there is a race condition for the update. + :param arguments: The input arguments to the call + ''' + log.debug("Updating the database context") + context = a[0] + readfunction = 0x03 # read holding registers + writefunction = 0x10 + slave_id = 0x01 # slave address + count = 50 + + # import pdb; pdb.set_trace() + + rand_value = random.randint(0, 9999) + rand_addr = random.randint(0, 65000) + log.debug("Writing to datastore: {}, {}".format(rand_addr, rand_value)) + # import pdb; pdb.set_trace() + context[slave_id].setValues(writefunction, rand_addr, [rand_value]) + values = context[slave_id].getValues(readfunction, rand_addr, count) + log.debug("Values from datastore: " + str(values)) + + + +#---------------------------------------------------------------------------# +# initialize your data store +#---------------------------------------------------------------------------# +block = ModbusSequentialDataBlock(0x00, [0]*0xff) +store = SqlSlaveContext(block) + +context = ModbusServerContext(slaves={1: store}, single=False) + + +#---------------------------------------------------------------------------# +# initialize the server information +#---------------------------------------------------------------------------# +identity = ModbusDeviceIdentification() +identity.VendorName = 'pymodbus' +identity.ProductCode = 'PM' +identity.VendorUrl = 'http://github.com/bashwork/pymodbus/' +identity.ProductName = 'pymodbus Server' +identity.ModelName = 'pymodbus Server' +identity.MajorMinorRevision = '1.0' + +#---------------------------------------------------------------------------# +# run the server you want +#---------------------------------------------------------------------------# +time = 5 # 5 seconds delay +loop = LoopingCall(f=updating_writer, a=(context,)) +loop.start(time, now=False) # initially delay by time +StartTcpServer(context, identity=identity, address=("", 5020)) diff --git a/examples/contrib/message-generator.py b/examples/contrib/message-generator.py index b9a1e8f0a..51146434b 100755 --- a/examples/contrib/message-generator.py +++ b/examples/contrib/message-generator.py @@ -12,6 +12,7 @@ * binary - `./generate-messages.py -f binary -m tx -b` ''' from optparse import OptionParser +import codecs as c #--------------------------------------------------------------------------# # import all the available framers #--------------------------------------------------------------------------# @@ -30,6 +31,7 @@ from pymodbus.mei_message import * from pymodbus.register_read_message import * from pymodbus.register_write_message import * +from pymodbus.compat import IS_PYTHON3 #--------------------------------------------------------------------------# # initialize logging @@ -51,17 +53,17 @@ WriteSingleRegisterRequest, WriteSingleCoilRequest, ReadWriteMultipleRegistersRequest, - + ReadExceptionStatusRequest, GetCommEventCounterRequest, GetCommEventLogRequest, ReportSlaveIdRequest, - + ReadFileRecordRequest, WriteFileRecordRequest, MaskWriteRegisterRequest, ReadFifoQueueRequest, - + ReadDeviceInformationRequest, ReturnQueryDataRequest, @@ -97,7 +99,7 @@ WriteSingleRegisterResponse, WriteSingleCoilResponse, ReadWriteMultipleRegistersResponse, - + ReadExceptionStatusResponse, GetCommEventCounterResponse, GetCommEventLogResponse, @@ -149,13 +151,13 @@ 'write_registers' : [0x01] * 8, 'transaction' : 0x01, 'protocol' : 0x00, - 'unit' : 0x01, + 'unit' : 0xff, } -#---------------------------------------------------------------------------# +#---------------------------------------------------------------------------# # generate all the requested messages -#---------------------------------------------------------------------------# +#---------------------------------------------------------------------------# def generate_messages(framer, options): ''' A helper method to parse the command line options @@ -168,13 +170,16 @@ def generate_messages(framer, options): print ("%-44s = " % message.__class__.__name__) packet = framer.buildPacket(message) if not options.ascii: - packet = packet.encode('hex') + '\n' - print (packet) # because ascii ends with a \r\n + if not IS_PYTHON3: + packet = packet.encode('hex') + else: + packet = c.encode(packet, 'hex_codec').decode('utf-8') + print ("{}\n".format(packet)) # because ascii ends with a \r\n -#---------------------------------------------------------------------------# +#---------------------------------------------------------------------------# # initialize our program settings -#---------------------------------------------------------------------------# +#---------------------------------------------------------------------------# def get_options(): ''' A helper method to parse the command line options diff --git a/examples/contrib/message-parser.py b/examples/contrib/message-parser.py index b5c653bf3..be8fc8b42 100755 --- a/examples/contrib/message-parser.py +++ b/examples/contrib/message-parser.py @@ -11,7 +11,7 @@ * rtu * binary ''' -#---------------------------------------------------------------------------# +#---------------------------------------------------------------------------# # import needed libraries #---------------------------------------------------------------------------# from __future__ import print_function @@ -19,12 +19,16 @@ import collections import textwrap from optparse import OptionParser +import codecs as c + from pymodbus.utilities import computeCRC, computeLRC from pymodbus.factory import ClientDecoder, ServerDecoder from pymodbus.transaction import ModbusSocketFramer from pymodbus.transaction import ModbusBinaryFramer from pymodbus.transaction import ModbusAsciiFramer from pymodbus.transaction import ModbusRtuFramer +from pymodbus.compat import byte2int, int2byte, IS_PYTHON3 + #--------------------------------------------------------------------------# # Logging @@ -33,9 +37,9 @@ modbus_log = logging.getLogger("pymodbus") -#---------------------------------------------------------------------------# +#---------------------------------------------------------------------------# # build a quick wrapper around the framers -#---------------------------------------------------------------------------# +#---------------------------------------------------------------------------# class Decoder(object): def __init__(self, framer, encode=False): @@ -52,7 +56,10 @@ def decode(self, message): :param message: The messge to decode ''' - value = message if self.encode else message.encode('hex') + if IS_PYTHON3: + value = message if self.encode else c.encode(message, 'hex_codec') + else: + value = message if self.encode else message.encode('hex') print("="*80) print("Decoding Message %s" % value) print("="*80) @@ -64,7 +71,7 @@ def decode(self, message): print("%s" % decoder.decoder.__class__.__name__) print("-"*80) try: - decoder.addToFrame(message.encode()) + decoder.addToFrame(message) if decoder.checkFrame(): decoder.advanceFrame() decoder.processIncomingPacket(message, self.report) @@ -86,7 +93,7 @@ def report(self, message): :param message: The message to print ''' print("%-15s = %s" % ('name', message.__class__.__name__)) - for k,v in message.__dict__.iteritems(): + for (k, v) in message.__dict__.items(): if isinstance(v, dict): print("%-15s =" % k) for kk,vv in v.items(): @@ -102,9 +109,9 @@ def report(self, message): print("%-15s = %s" % ('documentation', message.__doc__)) -#---------------------------------------------------------------------------# +#---------------------------------------------------------------------------# # and decode our message -#---------------------------------------------------------------------------# +#---------------------------------------------------------------------------# def get_options(): ''' A helper method to parse the command line options @@ -136,6 +143,10 @@ def get_options(): help="The file containing messages to parse", dest="file", default=None) + parser.add_option("-t", "--transaction", + help="If the incoming message is in hexadecimal format", + action="store_true", dest="transaction", default=False) + (opt, arg) = parser.parse_args() if not opt.message and len(arg) > 0: @@ -150,8 +161,19 @@ def get_messages(option): :returns: The message iterator to parse ''' if option.message: + if option.transaction: + msg = "" + for segment in option.message.split(): + segment = segment.replace("0x", "") + segment = "0" + segment if len(segment) == 1 else segment + msg = msg + segment + option.message = msg + if not option.ascii: - option.message = option.message.decode('hex') + if not IS_PYTHON3: + option.message = option.message.decode('hex') + else: + option.message = c.decode(option.message.encode(), 'hex_codec') yield option.message elif option.file: with open(option.file, "r") as handle: diff --git a/pymodbus/datastore/database/__init__.py b/pymodbus/datastore/database/__init__.py new file mode 100644 index 000000000..dbb2609a4 --- /dev/null +++ b/pymodbus/datastore/database/__init__.py @@ -0,0 +1,7 @@ +from pymodbus.datastore.database.sql_datastore import SqlSlaveContext +from pymodbus.datastore.database.redis_datastore import RedisSlaveContext + +#---------------------------------------------------------------------------# +# Exported symbols +#---------------------------------------------------------------------------# +__all__ = ["SqlSlaveContext", "RedisSlaveContext"] diff --git a/examples/contrib/redis-datastore.py b/pymodbus/datastore/database/redis_datastore.py similarity index 71% rename from examples/contrib/redis-datastore.py rename to pymodbus/datastore/database/redis_datastore.py index ef44c6544..b7c74b013 100644 --- a/examples/contrib/redis-datastore.py +++ b/pymodbus/datastore/database/redis_datastore.py @@ -29,7 +29,7 @@ def __init__(self, **kwargs): port = kwargs.get('port', 6379) self.prefix = kwargs.get('prefix', 'pymodbus') self.client = kwargs.get('client', redis.Redis(host=host, port=port)) - self.__build_mapping() + self._build_mapping() def __str__(self): ''' Returns a string representation of the context @@ -52,7 +52,7 @@ def validate(self, fx, address, count=1): ''' address = address + 1 # section 4.4 of specification _logger.debug("validate[%d] %d:%d" % (fx, address, count)) - return self.__val_callbacks[self.decode(fx)](address, count) + return self._val_callbacks[self.decode(fx)](address, count) def getValues(self, fx, address, count=1): ''' Validates the request to make sure it is in range @@ -64,7 +64,7 @@ def getValues(self, fx, address, count=1): ''' address = address + 1 # section 4.4 of specification _logger.debug("getValues[%d] %d:%d" % (fx, address, count)) - return self.__get_callbacks[self.decode(fx)](address, count) + return self._get_callbacks[self.decode(fx)](address, count) def setValues(self, fx, address, values): ''' Sets the datastore with the supplied values @@ -75,12 +75,12 @@ def setValues(self, fx, address, values): ''' address = address + 1 # section 4.4 of specification _logger.debug("setValues[%d] %d:%d" % (fx, address, len(values))) - self.__set_callbacks[self.decode(fx)](address, values) + self._set_callbacks[self.decode(fx)](address, values) #--------------------------------------------------------------------------# # Redis Helper Methods #--------------------------------------------------------------------------# - def __get_prefix(self, key): + def _get_prefix(self, key): ''' This is a helper to abstract getting bit values :param key: The key prefix to use @@ -88,52 +88,52 @@ def __get_prefix(self, key): ''' return "%s:%s" % (self.prefix, key) - def __build_mapping(self): + def _build_mapping(self): ''' A quick helper method to build the function code mapper. ''' - self.__val_callbacks = { - 'd' : lambda o, c: self.__val_bit('d', o, c), - 'c' : lambda o, c: self.__val_bit('c', o, c), - 'h' : lambda o, c: self.__val_reg('h', o, c), - 'i' : lambda o, c: self.__val_reg('i', o, c), + self._val_callbacks = { + 'd' : lambda o, c: self._val_bit('d', o, c), + 'c' : lambda o, c: self._val_bit('c', o, c), + 'h' : lambda o, c: self._val_reg('h', o, c), + 'i' : lambda o, c: self._val_reg('i', o, c), } - self.__get_callbacks = { - 'd' : lambda o, c: self.__get_bit('d', o, c), - 'c' : lambda o, c: self.__get_bit('c', o, c), - 'h' : lambda o, c: self.__get_reg('h', o, c), - 'i' : lambda o, c: self.__get_reg('i', o, c), + self._get_callbacks = { + 'd' : lambda o, c: self._get_bit('d', o, c), + 'c' : lambda o, c: self._get_bit('c', o, c), + 'h' : lambda o, c: self._get_reg('h', o, c), + 'i' : lambda o, c: self._get_reg('i', o, c), } - self.__set_callbacks = { - 'd' : lambda o, v: self.__set_bit('d', o, v), - 'c' : lambda o, v: self.__set_bit('c', o, v), - 'h' : lambda o, v: self.__set_reg('h', o, v), - 'i' : lambda o, v: self.__set_reg('i', o, v), + self._set_callbacks = { + 'd' : lambda o, v: self._set_bit('d', o, v), + 'c' : lambda o, v: self._set_bit('c', o, v), + 'h' : lambda o, v: self._set_reg('h', o, v), + 'i' : lambda o, v: self._set_reg('i', o, v), } #--------------------------------------------------------------------------# # Redis discrete implementation #--------------------------------------------------------------------------# - __bit_size = 16 - __bit_default = '\x00' * (__bit_size % 8) + _bit_size = 16 + _bit_default = '\x00' * (_bit_size % 8) - def __get_bit_values(self, key, offset, count): + def _get_bit_values(self, key, offset, count): ''' This is a helper to abstract getting bit values :param key: The key prefix to use :param offset: The address offset to start at :param count: The number of bits to read ''' - key = self.__get_prefix(key) - s = divmod(offset, self.__bit_size)[0] - e = divmod(offset + count, self.__bit_size)[0] + key = self._get_prefix(key) + s = divmod(offset, self._bit_size)[0] + e = divmod(offset + count, self._bit_size)[0] request = ('%s:%s' % (key, v) for v in range(s, e + 1)) response = self.client.mget(request) return response - def __val_bit(self, key, offset, count): + def _val_bit(self, key, offset, count): ''' Validates that the given range is currently set in redis. If any of the keys return None, then it is invalid. @@ -141,23 +141,23 @@ def __val_bit(self, key, offset, count): :param offset: The address offset to start at :param count: The number of bits to read ''' - response = self.__get_bit_values(key, offset, count) - return None not in response + response = self._get_bit_values(key, offset, count) + return True if None not in response else False - def __get_bit(self, key, offset, count): + def _get_bit(self, key, offset, count): ''' :param key: The key prefix to use :param offset: The address offset to start at :param count: The number of bits to read ''' - response = self.__get_bit_values(key, offset, count) - response = (r or self.__bit_default for r in response) + response = self._get_bit_values(key, offset, count) + response = (r or self._bit_default for r in response) result = ''.join(response) result = unpack_bitstring(result) return result[offset:offset + count] - def __set_bit(self, key, offset, values): + def _set_bit(self, key, offset, values): ''' :param key: The key prefix to use @@ -165,17 +165,17 @@ def __set_bit(self, key, offset, values): :param values: The values to set ''' count = len(values) - s = divmod(offset, self.__bit_size)[0] - e = divmod(offset + count, self.__bit_size)[0] + s = divmod(offset, self._bit_size)[0] + e = divmod(offset + count, self._bit_size)[0] value = pack_bitstring(values) - current = self.__get_bit_values(key, offset, count) - current = (r or self.__bit_default for r in current) + current = self._get_bit_values(key, offset, count) + current = (r or self._bit_default for r in current) current = ''.join(current) - current = current[0:offset] + value + current[offset + count:] - final = (current[s:s + self.__bit_size] for s in range(0, count, self.__bit_size)) + current = current[0:offset] + value.decode('utf-8') + current[offset + count:] + final = (current[s:s + self._bit_size] for s in range(0, count, self._bit_size)) - key = self.__get_prefix(key) + key = self._get_prefix(key) request = ('%s:%s' % (key, v) for v in range(s, e + 1)) request = dict(zip(request, final)) self.client.mset(request) @@ -183,17 +183,17 @@ def __set_bit(self, key, offset, values): #--------------------------------------------------------------------------# # Redis register implementation #--------------------------------------------------------------------------# - __reg_size = 16 - __reg_default = '\x00' * (__reg_size % 8) + _reg_size = 16 + _reg_default = '\x00' * (_reg_size % 8) - def __get_reg_values(self, key, offset, count): + def _get_reg_values(self, key, offset, count): ''' This is a helper to abstract getting register values :param key: The key prefix to use :param offset: The address offset to start at :param count: The number of bits to read ''' - key = self.__get_prefix(key) + key = self._get_prefix(key) #s = divmod(offset, self.__reg_size)[0] #e = divmod(offset+count, self.__reg_size)[0] @@ -202,7 +202,7 @@ def __get_reg_values(self, key, offset, count): response = self.client.mget(request) return response - def __val_reg(self, key, offset, count): + def _val_reg(self, key, offset, count): ''' Validates that the given range is currently set in redis. If any of the keys return None, then it is invalid. @@ -210,21 +210,21 @@ def __val_reg(self, key, offset, count): :param offset: The address offset to start at :param count: The number of bits to read ''' - response = self.__get_reg_values(key, offset, count) + response = self._get_reg_values(key, offset, count) return None not in response - def __get_reg(self, key, offset, count): + def _get_reg(self, key, offset, count): ''' :param key: The key prefix to use :param offset: The address offset to start at :param count: The number of bits to read ''' - response = self.__get_reg_values(key, offset, count) - response = [r or self.__reg_default for r in response] + response = self._get_reg_values(key, offset, count) + response = [r or self._reg_default for r in response] return response[offset:offset + count] - def __set_reg(self, key, offset, values): + def _set_reg(self, key, offset, values): ''' :param key: The key prefix to use @@ -237,7 +237,7 @@ def __set_reg(self, key, offset, values): #current = self.__get_reg_values(key, offset, count) - key = self.__get_prefix(key) + key = self._get_prefix(key) request = ('%s:%s' % (key, v) for v in range(offset, count + 1)) request = dict(zip(request, values)) self.client.mset(request) diff --git a/examples/contrib/database-datastore.py b/pymodbus/datastore/database/sql_datastore.py similarity index 80% rename from examples/contrib/database-datastore.py rename to pymodbus/datastore/database/sql_datastore.py index c1c48b161..a02894251 100644 --- a/examples/contrib/database-datastore.py +++ b/pymodbus/datastore/database/sql_datastore.py @@ -17,7 +17,7 @@ #---------------------------------------------------------------------------# # Context #---------------------------------------------------------------------------# -class DatabaseSlaveContext(IModbusSlaveContext): +class SqlSlaveContext(IModbusSlaveContext): ''' This creates a modbus data model with each data access stored in its own personal block @@ -30,7 +30,7 @@ def __init__(self, *args, **kwargs): ''' self.table = kwargs.get('table', 'pymodbus') self.database = kwargs.get('database', 'sqlite:///pymodbus.db') - self.__db_create(self.table, self.database) + self._db_create(self.table, self.database) def __str__(self): ''' Returns a string representation of the context @@ -42,8 +42,7 @@ def __str__(self): def reset(self): ''' Resets all the datastores to their default values ''' self._metadata.drop_all() - self.__db_create(self.table, self.database) - raise NotImplementedException() # TODO drop table? + self._db_create(self.table, self.database) def validate(self, fx, address, count=1): ''' Validates the request to make sure it is in range @@ -55,7 +54,7 @@ def validate(self, fx, address, count=1): ''' address = address + 1 # section 4.4 of specification _logger.debug("validate[%d] %d:%d" % (fx, address, count)) - return self.__validate(self.decode(fx), address, count) + return self._validate(self.decode(fx), address, count) def getValues(self, fx, address, count=1): ''' Validates the request to make sure it is in range @@ -67,7 +66,7 @@ def getValues(self, fx, address, count=1): ''' address = address + 1 # section 4.4 of specification _logger.debug("get-values[%d] %d:%d" % (fx, address, count)) - return self.__get(self.decode(fx), address, count) + return self._get(self.decode(fx), address, count) def setValues(self, fx, address, values): ''' Sets the datastore with the supplied values @@ -78,12 +77,12 @@ def setValues(self, fx, address, values): ''' address = address + 1 # section 4.4 of specification _logger.debug("set-values[%d] %d:%d" % (fx, address, len(values))) - self.__set(self.decode(fx), address, values) + self._set(self.decode(fx), address, values) #--------------------------------------------------------------------------# # Sqlite Helper Methods #--------------------------------------------------------------------------# - def __db_create(self, table, database): + def _db_create(self, table, database): ''' A helper method to initialize the database and handles :param table: The table name to create @@ -99,9 +98,8 @@ def __db_create(self, table, database): self._table.create(checkfirst=True) self._connection = self._engine.connect() - def __get(self, type, offset, count): + def _get(self, type, offset, count): ''' - :param type: The key prefix to use :param offset: The address offset to start at :param count: The number of bits to read @@ -110,47 +108,56 @@ def __get(self, type, offset, count): query = self._table.select(and_( self._table.c.type == type, self._table.c.index >= offset, - self._table.c.index <= offset + count)) + self._table.c.index <= offset + count) + ) query = query.order_by(self._table.c.index.asc()) result = self._connection.execute(query).fetchall() return [row.value for row in result] - def __build_set(self, type, offset, values, p=''): + def _build_set(self, type, offset, values, prefix=''): ''' A helper method to generate the sql update context :param type: The key prefix to use :param offset: The address offset to start at :param values: The values to set + :param prefix: Prefix fields index and type, defaults to empty string ''' result = [] for index, value in enumerate(values): result.append({ - p + 'type' : type, - p + 'index' : offset + index, + prefix + 'type' : type, + prefix + 'index' : offset + index, 'value' : value }) return result - def __set(self, type, offset, values): + def _check(self, type, offset, values): + result = self._get(type, offset, count=1) + return False if len(result) > 0 else True + + def _set(self, type, offset, values): ''' :param key: The type prefix to use :param offset: The address offset to start at :param values: The values to set ''' - context = self.__build_set(type, offset, values) - query = self._table.insert() - result = self._connection.execute(query, context) - return result.rowcount == len(values) - - def __update(self, type, offset, values): + if self._check(type, offset, values): + context = self._build_set(type, offset, values) + query = self._table.insert() + result = self._connection.execute(query, context) + return result.rowcount == len(values) + else: + return False + + def _update(self, type, offset, values): ''' :param type: The type prefix to use :param offset: The address offset to start at :param values: The values to set ''' - context = self.__build_set(type, offset, values, p='x_') + context = self._build_set(type, offset, values, prefix='x_') query = self._table.update().values(name='value') query = query.where(and_( self._table.c.type == bindparam('x_type'), @@ -158,7 +165,7 @@ def __update(self, type, offset, values): result = self._connection.execute(query, context) return result.rowcount == len(values) - def __validate(self, key, offset, count): + def _validate(self, type, offset, count): ''' :param key: The key prefix to use :param offset: The address offset to start at diff --git a/pymodbus/server/sync.py b/pymodbus/server/sync.py index f4beeea91..78ec598bc 100644 --- a/pymodbus/server/sync.py +++ b/pymodbus/server/sync.py @@ -102,7 +102,10 @@ def handle(self): if data: if _logger.isEnabledFor(logging.DEBUG): _logger.debug(" ".join([hex(byte2int(x)) for x in data])) - unit_address = byte2int(data[0]) + if not isinstance(self.framer, ModbusBinaryFramer): + unit_address = byte2int(data[0]) + else: + unit_address = byte2int(data[1]) if unit_address in self.server.context: self.framer.processIncomingPacket(data, self.execute) except Exception as msg: @@ -273,13 +276,14 @@ def __init__(self, context, framer=None, identity=None, address=None, handler=No self.context = context or ModbusServerContext() self.control = ModbusControlBlock() self.address = address or ("", Defaults.Port) + self.handler = handler or ModbusConnectedRequestHandler self.ignore_missing_slaves = kwargs.get('ignore_missing_slaves', Defaults.IgnoreMissingSlaves) if isinstance(identity, ModbusDeviceIdentification): self.control.Identity.update(identity) socketserver.ThreadingTCPServer.__init__(self, - self.address, ModbusConnectedRequestHandler) + self.address, self.handler) def process_request(self, request, client): ''' Callback for connecting a new client thread @@ -336,13 +340,14 @@ def __init__(self, context, framer=None, identity=None, address=None, handler=No self.context = context or ModbusServerContext() self.control = ModbusControlBlock() self.address = address or ("", Defaults.Port) + self.handler = handler or ModbusDisconnectedRequestHandler self.ignore_missing_slaves = kwargs.get('ignore_missing_slaves', Defaults.IgnoreMissingSlaves) if isinstance(identity, ModbusDeviceIdentification): self.control.Identity.update(identity) socketserver.ThreadingUDPServer.__init__(self, - self.address, ModbusDisconnectedRequestHandler) + self.address, self.handler) def process_request(self, request, client): ''' Callback for connecting a new client thread diff --git a/pymodbus/transaction.py b/pymodbus/transaction.py index 1efd3c17f..cc44be438 100644 --- a/pymodbus/transaction.py +++ b/pymodbus/transaction.py @@ -461,7 +461,7 @@ def processIncomingPacket(self, data, callback): def _process(self, callback, error=False): """ - Process incoming packets irrespective error condition + Process incoming packets irrespective error condition """ data = self.getRawFrame() if error else self.getFrame() result = self.decoder.decode(data) @@ -487,7 +487,7 @@ def resetFrame(self): def getRawFrame(self): """ - Returns the complete buffer + Returns the complete buffer """ return self.__buffer @@ -922,7 +922,7 @@ def __init__(self, decoder): ''' self.__buffer = b'' self.__header = {'crc':0x0000, 'len':0, 'uid':0x00} - self.__hsize = 0x02 + self.__hsize = 0x01 self.__start = b'\x7b' # { self.__end = b'\x7d' # } self.__repeat = [b'}'[0], b'{'[0]] # python3 hack diff --git a/pymodbus/utilities.py b/pymodbus/utilities.py index e3ef421e2..a15515acf 100644 --- a/pymodbus/utilities.py +++ b/pymodbus/utilities.py @@ -86,7 +86,10 @@ def unpack_bitstring(string): byte_count = len(string) bits = [] for byte in range(byte_count): - value = byte2int(string[byte]) + if IS_PYTHON3: + value = byte2int(int(string[byte])) + else: + value = byte2int(string[byte]) for _ in range(8): bits.append((value & 1) == 1) value >>= 1 @@ -96,8 +99,8 @@ def unpack_bitstring(string): def make_byte_string(s): """ Returns byte string from a given string, python3 specific fix - :param s: - :return: + :param s: + :return: """ if IS_PYTHON3 and isinstance(s, string_types): s = s.encode() diff --git a/requirements-tests.txt b/requirements-tests.txt index 85623bb57..5c4639d1b 100644 --- a/requirements-tests.txt +++ b/requirements-tests.txt @@ -8,5 +8,8 @@ Twisted>=17.1.0 zope.interface>=4.4.0 pyasn1>=0.2.3 pycrypto>=2.6.1 +pyserial>=3.4 +redis>=2.10.5 +sqlalchemy>=1.1.15 #wsgiref>=0.1.2 cryptography>=1.8.1 \ No newline at end of file diff --git a/test/test_datastore.py b/test/test_datastore.py index b6b401517..c1d99c33c 100644 --- a/test/test_datastore.py +++ b/test/test_datastore.py @@ -1,7 +1,13 @@ #!/usr/bin/env python import unittest +import mock +from mock import MagicMock +import redis +import random from pymodbus.datastore import * from pymodbus.datastore.store import BaseModbusDataBlock +from pymodbus.datastore.database import SqlSlaveContext +from pymodbus.datastore.database import RedisSlaveContext from pymodbus.exceptions import NotImplementedException from pymodbus.exceptions import NoSuchSlaveException from pymodbus.exceptions import ParameterException @@ -113,7 +119,7 @@ def testModbusSlaveContext(self): } context = ModbusSlaveContext(**store) self.assertNotEqual(str(context), None) - + for fx in [1,2,3,4]: context.setValues(fx, 0, [True]*10) self.assertTrue(context.validate(fx, 0,10)) @@ -132,6 +138,236 @@ def _set(ctx): self.assertRaises(NoSuchSlaveException, lambda: _set(context)) self.assertRaises(NoSuchSlaveException, lambda: context[0xffff]) + +class RedisDataStoreTest(unittest.TestCase): + ''' + This is the unittest for the pymodbus.datastore.database.redis module + ''' + + def setUp(self): + self.slave = RedisSlaveContext() + + def tearDown(self): + ''' Cleans up the test environment ''' + pass + + def testStr(self): + # slave = RedisSlaveContext() + self.assertEqual(str(self.slave), "Redis Slave Context %s" % self.slave.client) + + def testReset(self): + assert isinstance(self.slave.client, redis.Redis) + self.slave.client = MagicMock() + self.slave.reset() + self.slave.client.flushall.assert_called_once_with() + + def testValCallbacksSuccess(self): + self.slave._build_mapping() + mock_count = 3 + mock_offset = 0 + self.slave.client.mset = MagicMock() + self.slave.client.mget = MagicMock(return_value=['11']) + + for key in ('d', 'c', 'h', 'i'): + self.assertTrue( + self.slave._val_callbacks[key](mock_offset, mock_count) + ) + + def testValCallbacksFailure(self): + self.slave._build_mapping() + mock_count = 3 + mock_offset = 0 + self.slave.client.mset = MagicMock() + self.slave.client.mget = MagicMock(return_value=['11', None]) + + for key in ('d', 'c', 'h', 'i'): + self.assertFalse( + self.slave._val_callbacks[key](mock_offset, mock_count) + ) + + def testGetCallbacks(self): + self.slave._build_mapping() + mock_count = 3 + mock_offset = 0 + self.slave.client.mget = MagicMock(return_value='11') + + for key in ('d', 'c'): + resp = self.slave._get_callbacks[key](mock_offset, mock_count) + self.assertEqual(resp, [True, False, False]) + + for key in ('h', 'i'): + resp = self.slave._get_callbacks[key](mock_offset, mock_count) + self.assertEqual(resp, ['1', '1']) + + def testSetCallbacks(self): + self.slave._build_mapping() + mock_values = [3] + mock_offset = 0 + self.slave.client.mset = MagicMock() + self.slave.client.mget = MagicMock() + + for key in ['c', 'd']: + self.slave._set_callbacks[key](mock_offset, [3]) + k = "pymodbus:{}:{}".format(key, mock_offset) + self.slave.client.mset.assert_called_with( + {k: '\x01'} + ) + + for key in ('h', 'i'): + self.slave._set_callbacks[key](mock_offset, [3]) + k = "pymodbus:{}:{}".format(key, mock_offset) + self.slave.client.mset.assert_called_with( + {k: mock_values[0]} + ) + + def testValidate(self): + self.slave.client.mget = MagicMock(return_value=[123]) + self.assertTrue(self.slave.validate(0x01, 3000)) + + def testSetValue(self): + self.slave.client.mset = MagicMock() + self.slave.client.mget = MagicMock() + self.assertEqual(self.slave.setValues(0x01, 1000, [12]), None) + + def testGetValue(self): + self.slave.client.mget = MagicMock(return_value=["123"]) + self.assertEqual(self.slave.getValues(0x01, 23), []) + + +class MockSqlResult(object): + def __init__(self, rowcount=0, value=0): + self.rowcount = rowcount + self.value = value + + +class SqlDataStoreTest(unittest.TestCase): + ''' + This is the unittest for the pymodbus.datastore.database.SqlSlaveContesxt + module + ''' + + def setUp(self): + self.slave = SqlSlaveContext() + self.slave._metadata.drop_all = MagicMock() + self.slave._db_create = MagicMock() + self.slave._table.select = MagicMock() + self.slave._connection = MagicMock() + + self.mock_addr = random.randint(0, 65000) + self.mock_values = random.sample(range(1, 100), 5) + self.mock_function = 0x01 + self.mock_type = 'h' + self.mock_offset = 0 + self.mock_count = 1 + + self.function_map = {2: 'd', 4: 'i'} + self.function_map.update([(i, 'h') for i in [3, 6, 16, 22, 23]]) + self.function_map.update([(i, 'c') for i in [1, 5, 15]]) + + def tearDown(self): + ''' Cleans up the test environment ''' + pass + + def testStr(self): + self.assertEqual(str(self.slave), "Modbus Slave Context") + + def testReset(self): + self.slave.reset() + + self.slave._metadata.drop_all.assert_called_once_with() + self.slave._db_create.assert_called_once_with( + self.slave.table, self.slave.database + ) + def testValidateSuccess(self): + mock_result = MockSqlResult( + rowcount=len(self.mock_values) + ) + self.slave._connection.execute = MagicMock(return_value=mock_result) + self.assertTrue(self.slave.validate( + self.mock_function, self.mock_addr, len(self.mock_values)) + ) + + def testValidateFailure(self): + wrong_count = 9 + mock_result = MockSqlResult(rowcount=len(self.mock_values)) + self.slave._connection.execute = MagicMock(return_value=mock_result) + self.assertFalse(self.slave.validate( + self.mock_function, self.mock_addr, wrong_count) + ) + + def testBuildSet(self): + mock_set = [ + { + 'index': 0, + 'type': 'h', + 'value': 11 + }, + { + 'index': 1, + 'type': 'h', + 'value': 12 + } + ] + self.assertListEqual(self.slave._build_set('h', 0, [11, 12]), mock_set) + + def testCheckSuccess(self): + mock_success_results = [1, 2, 3] + self.slave._get = MagicMock(return_value=mock_success_results) + self.assertFalse(self.slave._check('h', 0, 1)) + + def testCheckFailure(self): + mock_success_results = [] + self.slave._get = MagicMock(return_value=mock_success_results) + self.assertTrue(self.slave._check('h', 0, 1)) + + def testGetValues(self): + self.slave._get = MagicMock() + + for key, value in self.function_map.items(): + self.slave.getValues(key, self.mock_addr, self.mock_count) + self.slave._get.assert_called_with( + value, self.mock_addr + 1, self.mock_count + ) + + def testSetValues(self): + self.slave._set = MagicMock() + + for key, value in self.function_map.items(): + self.slave.setValues(key, self.mock_addr, self.mock_values) + self.slave._set.assert_called_with( + value, self.mock_addr + 1, self.mock_values + ) + + def testSet(self): + self.slave._check = MagicMock(return_value=True) + self.slave._connection.execute = MagicMock( + return_value=MockSqlResult(rowcount=len(self.mock_values)) + ) + self.assertTrue(self.slave._set( + self.mock_type, self.mock_offset, self.mock_values) + ) + + self.slave._check = MagicMock(return_value=False) + self.assertFalse( + self.slave._set(self.mock_type, self.mock_offset, self.mock_values) + ) + + def testUpdateSuccess(self): + self.slave._connection.execute = MagicMock( + return_value=MockSqlResult(rowcount=len(self.mock_values)) + ) + self.assertTrue( + self.slave._update(self.mock_type, self.mock_offset, self.mock_values) + ) + + def testUpdateFailure(self): + self.slave._connection.execute = MagicMock( + return_value=MockSqlResult(rowcount=100) + ) + self.assertFalse( + self.slave._update(self.mock_type, self.mock_offset, self.mock_values) + ) + #---------------------------------------------------------------------------# # Main #---------------------------------------------------------------------------# diff --git a/test/test_transaction.py b/test/test_transaction.py index 7a90ed165..8c5e18f47 100644 --- a/test/test_transaction.py +++ b/test/test_transaction.py @@ -32,9 +32,9 @@ def tearDown(self): del self._rtu del self._ascii - #---------------------------------------------------------------------------# + #---------------------------------------------------------------------------# # Dictionary based transaction manager - #---------------------------------------------------------------------------# + #---------------------------------------------------------------------------# def testDictTransactionManagerTID(self): ''' Test the dict transaction manager TID ''' for tid in range(1, self._manager.getNextTID() + 10): @@ -65,9 +65,9 @@ class Request: pass self._manager.delTransaction(handle.transaction_id) self.assertEqual(None, self._manager.getTransaction(handle.transaction_id)) - #---------------------------------------------------------------------------# + #---------------------------------------------------------------------------# # Queue based transaction manager - #---------------------------------------------------------------------------# + #---------------------------------------------------------------------------# def testFifoTransactionManagerTID(self): ''' Test the fifo transaction manager TID ''' for tid in range(1, self._queue_manager.getNextTID() + 10): @@ -98,7 +98,7 @@ class Request: pass self._queue_manager.delTransaction(handle.transaction_id) self.assertEqual(None, self._queue_manager.getTransaction(handle.transaction_id)) - #---------------------------------------------------------------------------# + #---------------------------------------------------------------------------# # TCP tests #---------------------------------------------------------------------------# def testTCPFramerTransactionReady(self): @@ -361,7 +361,7 @@ def testBinaryFramerTransactionReady(self): def testBinaryFramerTransactionFull(self): ''' Test a full binary frame transaction ''' msg = b'\x7b\x01\x03\x00\x00\x00\x05\x85\xC9\x7d' - pack = msg[3:-3] + pack = msg[2:-3] self._binary.addToFrame(msg) self.assertTrue(self._binary.checkFrame()) result = self._binary.getFrame() @@ -372,7 +372,7 @@ def testBinaryFramerTransactionHalf(self): ''' Test a half completed binary frame transaction ''' msg1 = b'\x7b\x01\x03\x00' msg2 = b'\x00\x00\x05\x85\xC9\x7d' - pack = msg1[3:] + msg2[:-3] + pack = msg1[2:] + msg2[:-3] self._binary.addToFrame(msg1) self.assertFalse(self._binary.checkFrame()) result = self._binary.getFrame()