diff --git a/README.md b/README.md
index d08e61f..0b14e1e 100644
--- a/README.md
+++ b/README.md
@@ -295,6 +295,28 @@ async def msg_handler(msgs, error, context):
 consumer.consume(msg_handler)
 ```
 
+### Fetch a single batch of messages
+```python
+msgs = await memphis.fetch_messages(
+  station_name="",
+  consumer_name="",
+  consumer_group="", # defaults to the consumer name
+  batch_size=10, # defaults to 10
+  batch_max_time_to_wait_ms=5000, # defaults to 5000
+  max_ack_time_ms=30000, # defaults to 30000
+  max_msg_deliveries=10, # defaults to 10
+  generate_random_suffix=False,
+  start_consume_from_sequence=1, # start consuming from a specific sequence. defaults to 1
+  last_messages=-1 # consume the last N messages, defaults to -1 (all messages in the station)
+)
+```
+
+### Fetch a single batch of messages after creating a consumer
+```python
+msgs = await consumer.fetch(batch_size=10) # defaults to 10
+```
+
+
 ### Acknowledge a message
 Acknowledge a message indicates the Memphis server to not re-send the same message again to the same consumer / consumers group
 
diff --git a/memphis/memphis.py b/memphis/memphis.py
index c8fee2a..831bc6f 100644
--- a/memphis/memphis.py
+++ b/memphis/memphis.py
@@ -83,6 +83,7 @@ def __init__(self):
         self.update_configurations_sub = {}
         self.configuration_tasks = {}
         self.producers_map = dict()
+        self.consumers_map = dict()
 
     async def get_msgs_update_configurations(self, iterable: Iterable):
         try:
@@ -329,6 +330,9 @@ async def close(self):
             if self.update_configurations_sub is not None:
                 await self.update_configurations_sub.unsubscribe()
             self.producers_map.clear()
+            for consumer in self.consumers_map.values():
+                consumer.dls_messages.clear()
+            self.consumers_map.clear()
         except:
             return
 
@@ -387,45 +391,45 @@ async def producer(
             if create_res["error"] != "":
                 raise MemphisError(create_res["error"])
 
-            station_name_internal = get_internal_name(station_name)
-            self.station_schemaverse_to_dls[station_name_internal] = create_res[
+            internal_station_name = get_internal_name(station_name)
+            self.station_schemaverse_to_dls[internal_station_name] = create_res[
                 "schemaverse_to_dls"
             ]
             self.cluster_configurations["send_notification"] = create_res[
                 "send_notification"
             ]
             await self.start_listen_for_schema_updates(
-                station_name_internal, create_res["schema_update"]
+                internal_station_name, create_res["schema_update"]
             )
 
-            if self.schema_updates_data[station_name_internal] != {}:
+            if self.schema_updates_data[internal_station_name] != {}:
                 if (
-                    self.schema_updates_data[station_name_internal]["type"]
+                    self.schema_updates_data[internal_station_name]["type"]
                     == "protobuf"
                 ):
-                    self.parse_descriptor(station_name_internal)
-                if self.schema_updates_data[station_name_internal]["type"] == "json":
-                    schema = self.schema_updates_data[station_name_internal][
+                    self.parse_descriptor(internal_station_name)
+                if self.schema_updates_data[internal_station_name]["type"] == "json":
+                    schema = self.schema_updates_data[internal_station_name][
                         "active_version"
                     ]["schema_content"]
-                    self.json_schemas[station_name_internal] = json.loads(schema)
+                    self.json_schemas[internal_station_name] = json.loads(schema)
                 elif (
-                    self.schema_updates_data[station_name_internal]["type"] == "graphql"
+                    self.schema_updates_data[internal_station_name]["type"] == "graphql"
                 ):
-                    self.graphql_schemas[station_name_internal] = build_graphql_schema(
-                        self.schema_updates_data[station_name_internal][
+                    self.graphql_schemas[internal_station_name] = build_graphql_schema(
+                        self.schema_updates_data[internal_station_name][
                             "active_version"
]["schema_content"] ) producer = Producer(self, producer_name, station_name, real_name) - map_key = station_name_internal + "_" + real_name + map_key = internal_station_name + "_" + real_name self.producers_map[map_key] = producer return producer except Exception as e: raise MemphisError(str(e)) from e - async def get_msg_schema_updates(self, station_name_internal, iterable): + async def get_msg_schema_updates(self, internal_station_name, iterable): async for msg in iterable: message = msg.data.decode("utf-8") message = json.loads(message) @@ -433,8 +437,8 @@ async def get_msg_schema_updates(self, station_name_internal, iterable): data = {} else: data = message["init"] - self.schema_updates_data[station_name_internal] = data - self.parse_descriptor(station_name_internal) + self.schema_updates_data[internal_station_name] = data + self.parse_descriptor(internal_station_name) def parse_descriptor(self, station_name): try: @@ -521,7 +525,7 @@ async def consumer( try: if not self.is_connection_active: raise MemphisError("Connection is dead") - + real_name = consumer_name.lower() if generate_random_suffix: consumer_name = self.__generateRandomSuffix(consumer_name) cg = consumer_name if not consumer_group else consumer_group @@ -563,7 +567,9 @@ async def consumer( if err_msg != "": raise MemphisError(err_msg) - return Consumer( + internal_station_name = get_internal_name(station_name) + map_key = internal_station_name + "_" + real_name + consumer = Consumer( self, station_name, consumer_name, @@ -576,6 +582,8 @@ async def consumer( start_consume_from_sequence=start_consume_from_sequence, last_messages=last_messages, ) + self.consumers_map[map_key] = consumer + return consumer except Exception as e: raise MemphisError(str(e)) from e @@ -604,8 +612,8 @@ async def produce( Exception: _description_ """ try: - station_name_internal = get_internal_name(station_name) - map_key = station_name_internal + "_" + producer_name.lower() + internal_station_name = get_internal_name(station_name) + map_key = internal_station_name + "_" + producer_name.lower() producer = None if map_key in self.producers_map: producer = self.producers_map[map_key] @@ -625,6 +633,52 @@ async def produce( except Exception as e: raise MemphisError(str(e)) from e + + async def fetch_messages( + self, + station_name: str, + consumer_name: str, + consumer_group: str = "", + batch_size: int = 10, + batch_max_time_to_wait_ms: int = 5000, + max_ack_time_ms: int = 30000, + max_msg_deliveries: int = 10, + generate_random_suffix: bool = False, + start_consume_from_sequence: int = 1, + last_messages: int = -1 + ): + """Consume a batch of messages. + Args:. + station_name (str): station name to consume messages from. + consumer_name (str): name for the consumer. + consumer_group (str, optional): consumer group name. Defaults to the consumer name. + batch_size (int, optional): pull batch size. Defaults to 10. + batch_max_time_to_wait_ms (int, optional): max time in miliseconds to wait between pulls. Defaults to 5000. + max_ack_time_ms (int, optional): max time for ack a message in miliseconds, in case a message not acked in this time period the Memphis broker will resend it. Defaults to 30000. + max_msg_deliveries (int, optional): max number of message deliveries, by default is 10. + generate_random_suffix (bool): false by default, if true concatenate a random suffix to consumer's name + start_consume_from_sequence(int, optional): start consuming from a specific sequence. defaults to 1. 
+            last_messages (int, optional): consume the last N messages. Defaults to -1 (all messages in the station).
+        Returns:
+            list: a list of Message objects
+        """
+        try:
+            consumer = None
+            if not self.is_connection_active:
+                raise MemphisError("Can't fetch messages without being connected!")
+            internal_station_name = get_internal_name(station_name)
+            consumer_map_key = internal_station_name + "_" + consumer_name.lower()
+            if consumer_map_key in self.consumers_map:
+                consumer = self.consumers_map[consumer_map_key]
+            else:
+                consumer = await self.consumer(station_name=station_name, consumer_name=consumer_name, consumer_group=consumer_group, batch_size=batch_size, batch_max_time_to_wait_ms=batch_max_time_to_wait_ms, max_ack_time_ms=max_ack_time_ms, max_msg_deliveries=max_msg_deliveries, generate_random_suffix=generate_random_suffix, start_consume_from_sequence=start_consume_from_sequence, last_messages=last_messages)
+            messages = await consumer.fetch(batch_size)
+            if messages is None:
+                messages = []
+            return messages
+        except Exception as e:
+            raise MemphisError(str(e)) from e
+
     def is_connected(self):
         return self.broker_manager.is_connected
 
@@ -646,17 +700,17 @@ async def destroy(self):
             if error != "" and not "not exist" in error:
                 raise MemphisError(error)
-            station_name_internal = get_internal_name(self.name)
-            sub = self.connection.schema_updates_subs.get(station_name_internal)
-            task = self.connection.schema_tasks.get(station_name_internal)
-            if station_name_internal in self.connection.schema_updates_data:
-                del self.connection.schema_updates_data[station_name_internal]
-            if station_name_internal in self.connection.schema_updates_subs:
-                del self.connection.schema_updates_subs[station_name_internal]
-            if station_name_internal in self.connection.producers_per_station:
-                del self.connection.producers_per_station[station_name_internal]
-            if station_name_internal in self.connection.schema_tasks:
-                del self.connection.schema_tasks[station_name_internal]
+            internal_station_name = get_internal_name(self.name)
+            sub = self.connection.schema_updates_subs.get(internal_station_name)
+            task = self.connection.schema_tasks.get(internal_station_name)
+            if internal_station_name in self.connection.schema_updates_data:
+                del self.connection.schema_updates_data[internal_station_name]
+            if internal_station_name in self.connection.schema_updates_subs:
+                del self.connection.schema_updates_subs[internal_station_name]
+            if internal_station_name in self.connection.producers_per_station:
+                del self.connection.producers_per_station[internal_station_name]
+            if internal_station_name in self.connection.schema_tasks:
+                del self.connection.schema_tasks[internal_station_name]
             if task is not None:
                 task.cancel()
             if sub is not None:
                 await sub.unsubscribe()
@@ -668,6 +722,12 @@
                 if self.name not in k
             }
 
+            self.connection.consumers_map = {
+                k: v
+                for k, v in self.connection.consumers_map.items()
+                if self.name not in k
+            }
+
         except Exception as e:
             raise MemphisError(str(e)) from e
 
@@ -932,29 +992,29 @@ async def destroy(self):
             if error != "" and not "not exist" in error:
                 raise Exception(error)
 
-            station_name_internal = get_internal_name(self.station_name)
+            internal_station_name = get_internal_name(self.station_name)
             producer_number = (
-                self.connection.producers_per_station.get(station_name_internal) - 1
+                self.connection.producers_per_station.get(internal_station_name) - 1
             )
             self.connection.producers_per_station[
-                station_name_internal
+                internal_station_name
             ] = producer_number
 
             if producer_number == 0:
-                sub = self.connection.schema_updates_subs.get(station_name_internal)
-                task = self.connection.schema_tasks.get(station_name_internal)
-                if station_name_internal in self.connection.schema_updates_data:
-                    del self.connection.schema_updates_data[station_name_internal]
-                if station_name_internal in self.connection.schema_updates_subs:
-                    del self.connection.schema_updates_subs[station_name_internal]
-                if station_name_internal in self.connection.schema_tasks:
-                    del self.connection.schema_tasks[station_name_internal]
+                sub = self.connection.schema_updates_subs.get(internal_station_name)
+                task = self.connection.schema_tasks.get(internal_station_name)
+                if internal_station_name in self.connection.schema_updates_data:
+                    del self.connection.schema_updates_data[internal_station_name]
+                if internal_station_name in self.connection.schema_updates_subs:
+                    del self.connection.schema_updates_subs[internal_station_name]
+                if internal_station_name in self.connection.schema_tasks:
+                    del self.connection.schema_tasks[internal_station_name]
                 if task is not None:
                     task.cancel()
                 if sub is not None:
                     await sub.unsubscribe()
 
-            map_key = station_name_internal + "_" + self.real_name
+            map_key = internal_station_name + "_" + self.real_name
             del self.connection.producers_map[map_key]
 
         except Exception as e:
@@ -965,6 +1025,7 @@ def default_error_handler(e):
     print("ping exception raised", e)
 
 
+
 class Consumer:
     def __init__(
         self,
@@ -995,9 +1056,13 @@ def __init__(
             error_callback = default_error_handler
         self.t_ping = asyncio.create_task(self.__ping_consumer(error_callback))
         self.start_consume_from_sequence = start_consume_from_sequence
-        self.last_messages = last_messages
         self.context = {}
+        self.dls_messages = []
+        self.dls_current_index = 0
+        self.dls_callback_func = None
+        self.t_dls = asyncio.create_task(self.__consume_dls())
+
 
     def set_context(self, context):
        """Set a context (dict) that will be passed to each message handler call."""
@@ -1005,8 +1070,8 @@ def set_context(self, context):
 
     def consume(self, callback):
         """Consume events."""
+        self.dls_callback_func = callback
         self.t_consume = asyncio.create_task(self.__consume(callback))
-        self.t_dls = asyncio.create_task(self.__consume_dls(callback))
 
     async def __consume(self, callback):
         subject = get_internal_name(self.station_name)
@@ -1039,7 +1104,7 @@ async def __consume(self, callback):
             else:
                 break
 
-    async def __consume_dls(self, callback):
+    async def __consume_dls(self):
         subject = get_internal_name(self.station_name)
         consumer_group = get_internal_name(self.consumer_group)
         try:
@@ -1048,16 +1113,58 @@
                 subscription_name, subscription_name
             )
             async for msg in self.consumer_dls.messages:
-                await callback(
-                    [Message(msg, self.connection, self.consumer_group)],
-                    None,
-                    self.context,
-                )
+                index_to_insert = self.dls_current_index
+                if index_to_insert >= 10000:
+                    index_to_insert %= 10000
+                self.dls_messages.insert(index_to_insert, Message(msg, self.connection, self.consumer_group))
+                self.dls_current_index += 1
+                if self.dls_callback_func is not None:
+                    await self.dls_callback_func(
+                        [Message(msg, self.connection, self.consumer_group)],
+                        None,
+                        self.context,
+                    )
         except Exception as e:
-            print("dls", e)
-            await callback([], MemphisError(str(e)), self.context)
+            if self.dls_callback_func is not None:
+                await self.dls_callback_func([], MemphisError(str(e)), self.context)
             return
 
+    async def fetch(self, batch_size: int = 10):
+        """Fetch a batch of messages."""
+        messages = []
+        if self.connection.is_connection_active:
+            try:
+                self.batch_size = batch_size
+                if len(self.dls_messages) > 0:
+                    if len(self.dls_messages) <= batch_size:
+                        messages = self.dls_messages
+                        self.dls_messages = []
+                        self.dls_current_index = 0
+                    else:
+                        messages = self.dls_messages[0:batch_size]
+                        del self.dls_messages[0:batch_size]
+                        self.dls_current_index -= len(messages)
+                    return messages
+
+                durableName = ""
+                if self.consumer_group != "":
+                    durableName = get_internal_name(self.consumer_group)
+                else:
+                    durableName = get_internal_name(self.consumer_name)
+                subject = get_internal_name(self.station_name)
+                consumer_group = get_internal_name(self.consumer_group)
+                self.psub = await self.connection.broker_connection.pull_subscribe(
+                    subject + ".final", durable=durableName
+                )
+                msgs = await self.psub.fetch(batch_size)
+                for msg in msgs:
+                    messages.append(Message(msg, self.connection, self.consumer_group))
+                return messages
+            except Exception as e:
+                if "timeout" not in str(e):
+                    raise MemphisError(str(e)) from e
+                else:
+                    return messages
+
     async def __ping_consumer(self, callback):
         while True:
             try:
@@ -1092,6 +1199,10 @@ async def destroy(self):
             error = res.data.decode("utf-8")
             if error != "" and not "not exist" in error:
                 raise MemphisError(error)
+            self.dls_messages.clear()
+            internal_station_name = get_internal_name(self.station_name)
+            map_key = internal_station_name + "_" + self.consumer_name.lower()
+            del self.connection.consumers_map[map_key]
         except Exception as e:
             raise MemphisError(str(e)) from e
 
@@ -1109,7 +1220,7 @@ async def ack(self):
         except Exception as e:
             if (
                 "$memphis_pm_id"
-                in self.message.headers & "$memphis_pm_sequence"
+                in self.message.headers and "$memphis_pm_sequence"
                 in self.message.headers
             ):
                 try:
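
For reference, a minimal usage sketch of the fetch path this change introduces. It assumes an already-connected `Memphis` instance named `memphis`, an existing station, and an async context; the station and consumer names below are placeholders, and `get_data()` is the existing payload accessor on `Message` (not part of this diff).

```python
# Pull one batch through the connection-level helper; the consumer is created
# on the first call and cached in consumers_map for subsequent calls.
msgs = await memphis.fetch_messages(
    station_name="test_station",     # placeholder station name
    consumer_name="batch_consumer",  # placeholder consumer name
    batch_size=10,
)
for msg in msgs:
    print(msg.get_data())  # raw message payload
    await msg.ack()        # ack so the broker does not redeliver
```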