Skip to content

Commit

Permalink
Adds new model options
Browse files Browse the repository at this point in the history
  • Loading branch information
vmesel committed Apr 20, 2024
1 parent c759c9a commit 62b3522
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 6 deletions.
19 changes: 14 additions & 5 deletions dialog_lib/db/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,13 @@ def __init__(
*args,
parent_session_id=None,
dbsession=None,
chats_model=Chat,
chat_messages_model=ChatMessages,
**kwargs,
):
self.parent_session_id = parent_session_id
self.dbsession = dbsession
self.chats_model = chats_model
self.chat_messages_model = chat_messages_model
super().__init__(*args, **kwargs)

Expand All @@ -38,14 +40,14 @@ def _create_table_if_not_exists(self) -> None:

def add_tags(self, tags: str) -> None:
"""Add tags for a given session_id/uuid on chats table"""
self.dbsession.query(Chat).where(Chat.session_id == self.session_id).update(
{Chat.tags: tags}
)
self.dbsession.query(self.chats_model).where(
self.chats_model.session_id == self.session_id
).update({getattr(self.chats_model, "tags"): tags})
self.dbsession.commit()

def add_message(self, message: BaseMessage) -> None:
"""Append the message to the record in PostgreSQL"""
message = ChatMessages(
message = self.chat_messages_model(
session_id=self.session_id, message=_message_to_dict(message)
)
if self.parent_session_id:
Expand All @@ -55,7 +57,12 @@ def add_message(self, message: BaseMessage) -> None:


def generate_memory_instance(
session_id, parent_session_id=None, dbsession=None, database_url=None
session_id,
parent_session_id=None,
dbsession=None,
database_url=None,
chats_model=Chat,
chat_messages_model=ChatMessages,
):
"""
Generate a memory instance for a given session_id
Expand All @@ -67,6 +74,8 @@ def generate_memory_instance(
parent_session_id=parent_session_id,
table_name="chat_messages",
dbsession=dbsession,
chats_model=chats_model,
chat_messages_model=chat_messages_model,
)


Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "dialog-lib"
version = "0.0.1.9"
version = "0.0.1.10"
description = ""
authors = ["Talkd.AI <foss@talkd.ai>"]
license = "MIT"
Expand Down

0 comments on commit 62b3522

Please sign in to comment.