Skip to content
32 changes: 18 additions & 14 deletions memphis/memphis.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from __future__ import annotations
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please explain what it is for?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The typing code below is only "legal" as-is in Python 3.11; adding `from __future__ import annotations` allows this typing syntax to work on earlier versions of Python as well. This way, type annotations are interpreted the same on all Python versions 3.7+.

class C:
    @classmethod
    def from_string(cls, source: str) -> C:

The new way of writing Union types is also legal: what used to be `Union[A, B]` can now be written `A | B`.

long read

# Credit for The NATS.IO Authors
# Copyright 2021-2022 The Memphis Authors
# Licensed under the Apache License, Version 2.0 (the “License”);
Expand All @@ -12,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Iterable, Callable, Union
import random
import json
import ssl
Expand All @@ -27,7 +29,6 @@
from jsonschema import validate
from google.protobuf import descriptor_pb2, descriptor_pool
from google.protobuf.message_factory import MessageFactory
from google.protobuf.message import Message

import memphis.retention_types as retention_types
import memphis.storage_types as storage_types
Expand All @@ -36,7 +37,7 @@


class set_interval():
def __init__(self, func, sec):
def __init__(self, func: Callable, sec: int):
def func_wrapper():
self.t = Timer(sec, func_wrapper)
self.t.start()
Expand All @@ -63,7 +64,7 @@ def __init__(self):
self.update_configurations_sub = {}
self.configuration_tasks = {}

async def get_msgs_update_configurations(self, iterable):
async def get_msgs_update_configurations(self, iterable: Iterable):
try:
async for msg in iterable:
message = msg.data.decode("utf-8")
Expand All @@ -87,7 +88,7 @@ async def configurations_listener(self):
except Exception as err:
raise MemphisError(err)

async def connect(self, host, username, connection_token, port=6666, reconnect=True, max_reconnect=10, reconnect_interval_ms=1500, timeout_ms=15000, cert_file='', key_file='', ca_file=''):
async def connect(self, host: str, username: str, connection_token: str, port: int = 6666, reconnect: bool = True, max_reconnect: int = 10, reconnect_interval_ms: int = 1500, timeout_ms: int = 15000, cert_file: str = "", key_file: str = "", ca_file: str = ""):
"""Creates connection with Memphis.
Args:
host (str): memphis host.
Expand Down Expand Up @@ -157,7 +158,8 @@ async def send_notification(self, title, msg, failedMsg, type):
msgToSend = json.dumps(msg).encode('utf-8')
await self.broker_manager.publish("$memphis_notifications", msgToSend)

async def station(self, name, retention_type=retention_types.MAX_MESSAGE_AGE_SECONDS, retention_value=604800, storage_type=storage_types.DISK, replicas=1, idempotency_window_ms=120000, schema_name="", send_poison_msg_to_dls=True, send_schema_failed_msg_to_dls=True):
async def station(self, name: str,
retention_type: str = retention_types.MAX_MESSAGE_AGE_SECONDS, retention_value: int = 604800, storage_type: str = storage_types.DISK, replicas: int = 1, idempotency_window_ms: int = 120000, schema_name: str = "", send_poison_msg_to_dls: bool = True, send_schema_failed_msg_to_dls: bool = True,):
"""Creates a station.
Args:
name (str): station name.
Expand Down Expand Up @@ -296,7 +298,7 @@ def __normalize_host(self, host):
else:
return host

async def producer(self, station_name, producer_name, generate_random_suffix=False):
async def producer(self, station_name: str, producer_name: str, generate_random_suffix: bool =False):
"""Creates a producer.
Args:
station_name (str): station name to produce messages into.
Expand Down Expand Up @@ -405,7 +407,7 @@ async def start_listen_for_schema_updates(self, station_name, schema_update_data
station_name, self.schema_updates_subs[station_name].messages))
self.schema_tasks[station_name] = task

async def consumer(self, station_name, consumer_name, consumer_group="", pull_interval_ms=1000, batch_size=10, batch_max_time_to_wait_ms=5000, max_ack_time_ms=30000, max_msg_deliveries=10, generate_random_suffix=False, start_consume_from_sequence=1, last_messages=-1):
async def consumer(self, station_name: str, consumer_name: str, consumer_group: str ="", pull_interval_ms: int = 1000, batch_size: int = 10, batch_max_time_to_wait_ms: int =5000, max_ack_time_ms: int=30000, max_msg_deliveries: int=10, generate_random_suffix: bool=False, start_consume_from_sequence: int=1, last_messages: int=-1):
"""Creates a consumer.
Args:.
station_name (str): station name to consume messages from.
Expand Down Expand Up @@ -486,7 +488,7 @@ def add(self, key, value):


class Station:
def __init__(self, connection, name):
def __init__(self, connection, name: str):
self.connection = connection
self.name = name.lower()

Expand Down Expand Up @@ -531,7 +533,7 @@ def get_internal_name(name: str) -> str:


class Producer:
def __init__(self, connection, producer_name, station_name):
def __init__(self, connection, producer_name: str, station_name: str):
self.connection = connection
self.producer_name = producer_name.lower()
self.station_name = station_name
Expand Down Expand Up @@ -628,10 +630,10 @@ def validate_graphql(self, message):
e = "Invalid message format, expected GraphQL"
raise Exception("Schema validation has failed: " + str(e))

def get_dls_msg_id(self, station_name, producer_name, unix_time):
def get_dls_msg_id(self, station_name: str, producer_name: str, unix_time: str):
return station_name + '~' + producer_name + '~0~' + unix_time

async def produce(self, message, ack_wait_sec=15, headers={}, async_produce=False, msg_id=None):
async def produce(self, message, ack_wait_sec: int = 15, headers: Union[Headers, None] = None, async_produce: bool=False, msg_id: Union[str, None]= None):
Copy link
Contributor Author

@siwikm siwikm Jan 30, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here is an example that reproduces this bug.

from random import randint
def f1(dic = {}) -> None:
    dic.update({str(randint(1, 10000)): "foo"})
    print(dic)

f1()
f1()
f1()

# prints:
# {'5494': 'foo'}
# {'5494': 'foo', '9910': 'foo'}
# {'5494': 'foo', '9910': 'foo', '245': 'foo'}

More reading on this, for example, here: url

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could this fix break code written against previous versions of the SDK?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, it should not, because I added a condition that checks whether the value is None.
However, because of this headers bug, there were cases where — even though the user did not pass headers to the produce function — if the previous message had headers, we would attach the wrong (previous) headers to a message produced without any headers argument.

Patching this is not hard: just change `headers = {}` to `headers = None`, and `if headers != {}:` to `if headers is not None:`. So this could be done.

# Before change
if headers != {}:

# After change
if headers is not None:

"""Produces a message into a station.
Args:
message (bytearray/dict): message to send into the station - bytearray/protobuf class (schema validated station - protobuf) or bytearray/dict (schema validated station - json schema) or string/bytearray/graphql.language.ast.DocumentNode (schema validated station - graphql schema)
Expand All @@ -650,10 +652,10 @@ async def produce(self, message, ack_wait_sec=15, headers={}, async_produce=Fals
"$memphis_producedBy": self.producer_name,
"$memphis_connectionId": self.connection.connection_id}

if msg_id != None:
if msg_id is not None:
memphis_headers["msg-id"] = msg_id

if headers != {}:
if headers is not None:
headers = headers.headers
headers.update(memphis_headers)
else:
Expand Down Expand Up @@ -760,7 +762,7 @@ async def default_error_handler(e):


class Consumer:
def __init__(self, connection, station_name, consumer_name, consumer_group, pull_interval_ms, batch_size, batch_max_time_to_wait_ms, max_ack_time_ms, max_msg_deliveries=10, error_callback=None, start_consume_from_sequence=1, last_messages=-1):
def __init__(self, connection, station_name: str, consumer_name, consumer_group, pull_interval_ms: int, batch_size: int, batch_max_time_to_wait_ms: int, max_ack_time_ms: int, max_msg_deliveries: int=10, error_callback=None, start_consume_from_sequence: int=1, last_messages: int=-1):
self.connection = connection
self.station_name = station_name.lower()
self.consumer_name = consumer_name.lower()
Expand All @@ -775,6 +777,7 @@ def __init__(self, connection, station_name, consumer_name, consumer_group, pull
error_callback = default_error_handler
self.t_ping = asyncio.create_task(self.__ping_consumer(error_callback))
self.start_consume_from_sequence = start_consume_from_sequence

self.last_messages= last_messages
self.context = {}

Expand Down Expand Up @@ -804,6 +807,7 @@ async def __consume(self, callback):
Message(msg, self.connection, self.consumer_group))
await callback(memphis_messages, None, self.context)
await asyncio.sleep(self.pull_interval_ms/1000)

except asyncio.TimeoutError:
await callback([], MemphisError("Memphis: TimeoutError"))
continue
Expand Down