Skip to content
14 changes: 6 additions & 8 deletions memphis/memphis.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,14 +325,13 @@ async def producer(self, station_name: str, producer_name: str, generate_random_
generate_random_suffix (bool): false by default, if true concatenate a random suffix to producer's name
Raises:
Exception: _description_
Exception: _description_
Returns:
_type_: _description_
"""
try:
if not self.is_connection_active:
raise MemphisError("Connection is dead")

real_name = producer_name.lower()
if generate_random_suffix:
producer_name = self.__generateRandomSuffix(producer_name)
createProducerReq = {
Expand Down Expand Up @@ -366,8 +365,8 @@ async def producer(self, station_name: str, producer_name: str, generate_random_
elif self.schema_updates_data[station_name_internal]['type'] == "graphql":
self.graphql_schemas[station_name_internal] = build_graphql_schema(
self.schema_updates_data[station_name_internal]['active_version']['schema_content'])
producer = Producer(self, producer_name, station_name)
map_key = station_name_internal+"_"+producer_name.lower()
producer = Producer(self, producer_name, station_name, real_name)
map_key = station_name_internal+"_"+real_name
self.producers_map[map_key] = producer
return producer

Expand Down Expand Up @@ -510,7 +509,6 @@ async def produce(self, station_name: str, producer_name: str, message, generate
producer = self.producers_map[map_key]
else:
producer = await self.producer(station_name=station_name, producer_name=producer_name, generate_random_suffix=generate_random_suffix)
self.producers_map[map_key] = producer
await producer.produce(message=message, ack_wait_sec=ack_wait_sec, headers=headers, async_produce= async_produce, msg_id=msg_id)
except Exception as e:
raise MemphisError(str(e)) from e
Expand Down Expand Up @@ -567,12 +565,13 @@ def get_internal_name(name: str) -> str:


class Producer:
def __init__(self, connection, producer_name: str, station_name: str):
def __init__(self, connection, producer_name: str, station_name: str, real_name: str):
self.connection = connection
self.producer_name = producer_name.lower()
self.station_name = station_name
self.internal_station_name = get_internal_name(self.station_name)
self.loop = asyncio.get_running_loop()
self.real_name = real_name

async def validate_msg(self, message):
if self.connection.schema_updates_data[self.internal_station_name] != {}:
Expand Down Expand Up @@ -677,7 +676,6 @@ async def produce(self, message, ack_wait_sec: int = 15, headers: Union[Headers,
msg_id (string, optional): Attach msg-id header to the message in order to achieve idempotency
Raises:
Exception: _description_
Exception: _description_
"""
try:
message = await self.validate_msg(message)
Expand Down Expand Up @@ -787,7 +785,7 @@ async def destroy(self):
if sub is not None:
await sub.unsubscribe()

map_key = station_name_internal+"_"+self.producer_name.lower()
map_key = station_name_internal+"_"+self.real_name
del self.connection.producers_map[map_key]

except Exception as e:
Expand Down