From ff1cdf05200195d21a35af979541ae1a63fb162d Mon Sep 17 00:00:00 2001 From: uezo Date: Fri, 12 Jul 2024 23:01:28 +0900 Subject: [PATCH 1/4] Migrate session storage from JSON to SQLAlchemy database Refactored session storage to use SQLAlchemy instead of JSON files. - Improved scalability and enabled tracking of user conversation_id history. - Users' past conversation_id history can now be retrieved for later reference. --- .gitignore | 2 +- linedify/integration.py | 73 +++--------------------------- linedify/session.py | 93 +++++++++++++++++++++++++++++++++++++++ requirements.txt | 1 + tests/test_integration.py | 57 ------------------------ tests/test_session.py | 78 ++++++++++++++++++++++++++++++++ 6 files changed, 180 insertions(+), 124 deletions(-) create mode 100644 linedify/session.py create mode 100644 tests/test_session.py diff --git a/.gitignore b/.gitignore index d85d651..cce6885 100644 --- a/.gitignore +++ b/.gitignore @@ -162,4 +162,4 @@ cython_debug/ examples/ pytest.ini .DS_store -test_sessions/ +*.db diff --git a/linedify/integration.py b/linedify/integration.py index 0654b0d..9d736e3 100644 --- a/linedify/integration.py +++ b/linedify/integration.py @@ -13,76 +13,13 @@ ImageMessage, SendMessage, TextSendMessage ) from .dify import DifyAgent, DifyType +from .session import ConversationSessionStore logger = getLogger(__name__) logger.addHandler(NullHandler()) -class ConversationSession: - def __init__(self, user_id: str, conversation_id: str = None, updated_at: datetime = None) -> None: - self.user_id = user_id - self.conversation_id = conversation_id - self.updated_at = updated_at or datetime.now(timezone.utc) - - def to_dict(self): - return { - "user_id": self.user_id, - "conversation_id": self.conversation_id, - "updated_at": self.updated_at.isoformat() - } - - @staticmethod - def from_dict(data): - return ConversationSession( - user_id=data["user_id"], - conversation_id=data.get("conversation_id"), - updated_at=datetime.fromisoformat(data["updated_at"]) - ) - - -class ConversationSessionStore: - def __init__(self, persisit_directory: str = "sessions", timeout: float = 3600.0) -> None: - self.timeout = timeout - self.directory = persisit_directory - if not os.path.exists(persisit_directory): - os.makedirs(persisit_directory) - - def _get_file_path(self, user_id: str) -> str: - return os.path.join(self.directory, f"{user_id}.json") - - def _load_session(self, user_id: str) -> ConversationSession: - file_path = self._get_file_path(user_id) - if not os.path.exists(file_path): - return None - - with open(file_path, "r") as file: - data = json.load(file) - return ConversationSession.from_dict(data) - - def _save_session(self, session: ConversationSession) -> None: - file_path = self._get_file_path(session.user_id) - with open(file_path, "w") as file: - json.dump(session.to_dict(), file) - - async def get_session(self, user_id: str) -> ConversationSession: - if not user_id: - raise Exception("user_id is required") - - session = self._load_session(user_id) - if session is None or (datetime.now(timezone.utc) - session.updated_at).total_seconds() > self.timeout: - session = ConversationSession(user_id) - - return session - - async def set_session(self, session: ConversationSession) -> None: - if not session.user_id: - raise Exception("user_id is required") - - session.updated_at = datetime.now(timezone.utc) - self._save_session(session) - - class LineDifyIntegrator: def __init__(self, *, line_channel_access_token: str, @@ -92,7 +29,8 @@ def __init__(self, *, dify_user: str, dify_type: DifyType = DifyType.Agent, error_response: str = None, - conversation_timeout: float = 3600.0, + session_db_url: str = "sqlite:///sessions.db", + session_timeout: float = 3600.0, verbose: bool = False ) -> None: @@ -113,7 +51,10 @@ def __init__(self, *, "location": self.parse_location_message } - self.conversation_session_store = ConversationSessionStore(timeout=conversation_timeout) + self.conversation_session_store = ConversationSessionStore( + db_url=session_db_url, + timeout=session_timeout + ) # Dify self.dify_agent = DifyAgent( diff --git a/linedify/session.py b/linedify/session.py new file mode 100644 index 0000000..c7bbe22 --- /dev/null +++ b/linedify/session.py @@ -0,0 +1,93 @@ +from datetime import datetime, timezone +from typing import List +from sqlalchemy import create_engine, Column, String, DateTime, UniqueConstraint +from sqlalchemy.orm import declarative_base, sessionmaker + + +class ConversationSession: + def __init__(self, user_id: str, conversation_id: str = None, updated_at: datetime = None) -> None: + self.user_id = user_id + self.conversation_id = conversation_id + self.updated_at = updated_at or datetime.now(timezone.utc) + + def to_dict(self): + return { + "user_id": self.user_id, + "conversation_id": self.conversation_id, + "updated_at": self.updated_at.isoformat() + } + + @staticmethod + def from_dict(data): + return ConversationSession( + user_id=data["user_id"], + conversation_id=data.get("conversation_id"), + updated_at=datetime.fromisoformat(data["updated_at"]) + ) + + +Base = declarative_base() + + +class ConversationSessionModel(Base): + __tablename__ = "conversation_sessions" + + id = Column(String, primary_key=True) + user_id = Column(String, nullable=False) + conversation_id = Column(String) + updated_at = Column(DateTime(timezone=True), nullable=False) + + __table_args__ = (UniqueConstraint("user_id", "conversation_id", name="uix_user_conversation"),) + + +class ConversationSessionStore: + def __init__(self, db_url: str = "sqlite:///sessions.db", timeout: float = 3600.0) -> None: + self.timeout = timeout + self.engine = create_engine(db_url) + Base.metadata.create_all(self.engine) + self.Session = sessionmaker(bind=self.engine) + + async def get_session(self, user_id: str) -> ConversationSession: + if not user_id: + raise Exception("user_id is required") + + with self.Session() as session: + db_session = session.query(ConversationSessionModel).filter_by(user_id=user_id).order_by(ConversationSessionModel.updated_at.desc()).first() + + now = datetime.now(timezone.utc) + + if db_session is None or (now - db_session.updated_at.replace(tzinfo=timezone.utc)).total_seconds() > self.timeout: + return ConversationSession(user_id) + + return ConversationSession( + user_id=db_session.user_id, + conversation_id=db_session.conversation_id, + updated_at=db_session.updated_at.replace(tzinfo=timezone.utc) + ) + + async def set_session(self, session: ConversationSession) -> None: + if not session.user_id: + raise Exception("user_id is required") + + session.updated_at = datetime.now(timezone.utc) + + with self.Session() as db_session: + db_session_model = ConversationSessionModel( + id=f"{session.user_id}_{session.conversation_id}", + user_id=session.user_id, + conversation_id=session.conversation_id, + updated_at=session.updated_at + ) + db_session.merge(db_session_model) + db_session.commit() + + async def get_user_conversations(self, user_id: str, count: int = 20) -> list[ConversationSession]: + with self.Session() as session: + db_sessions = session.query(ConversationSessionModel).filter_by(user_id=user_id).order_by(ConversationSessionModel.updated_at.desc()).limit(count) + user_conversations = [ConversationSession( + user_id=db_session.user_id, + conversation_id=db_session.conversation_id, + updated_at=db_session.updated_at.replace(tzinfo=timezone.utc) + ) for db_session in db_sessions] + user_conversations.reverse() + return user_conversations diff --git a/requirements.txt b/requirements.txt index 87a52dd..c02f442 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,3 +2,4 @@ aiohttp==3.9.5 line-bot-sdk==3.11.0 fastapi==0.111.0 uvicorn==0.30.1 +SQLAlchemy==2.0.31 diff --git a/tests/test_integration.py b/tests/test_integration.py index 440896d..078119a 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -1,10 +1,7 @@ import pytest -import asyncio import os -from datetime import datetime, timezone from linebot.models import TextMessage, ImageMessage, StickerMessage, LocationMessage from linedify import LineDify, DifyType -from linedify.integration import ConversationSession, ConversationSessionStore @pytest.fixture @@ -73,57 +70,3 @@ async def test_parse_text_message(line_dify): parsed_text, parsed_image = await line_dify.parse_location_message(text_message) assert parsed_text == f"You received a location info from user in messenger app:\n - address: Jiyugaoka, Tokyo\n - latitude: 35.6\n - longitude: 139.6" assert parsed_image is None - - -def test_conversation_session(): - now = datetime.now(timezone.utc) - session = ConversationSession("user_id", "conversation_id", now) - - assert session.user_id == "user_id" - assert session.conversation_id == "conversation_id" - assert session.updated_at == now - - session_dict = session.to_dict() - session_dict["user_id"] == session.user_id - session_dict["conversation_id"] == session.conversation_id - session_dict["updated_at"] == now.isoformat() - - session2 = ConversationSession.from_dict(session_dict) - assert session2.user_id == session.user_id - assert session2.conversation_id == session.conversation_id - assert session2.updated_at == session.updated_at - - -@pytest.mark.asyncio -async def test_conversation_session_store(): - store = ConversationSessionStore("test_sessions", 3) - assert store.directory == "test_sessions" - assert store.timeout == 3.0 - assert os.path.exists("test_sessions") is True - - # New session - session = await store.get_session("user_id") - assert session.user_id == "user_id" - assert session.conversation_id is None - assert isinstance(session.updated_at, datetime) - - last_updated_at = session.updated_at - session.conversation_id = "conversation_id" - await store.set_session(session) - - # Successive session - session2 = await store.get_session("user_id") - assert session2.user_id == "user_id" - assert session2.conversation_id == "conversation_id" - assert session2.updated_at > last_updated_at - - last_updated_at = session2.updated_at - await store.set_session(session2) - - await asyncio.sleep(store.timeout + 1.0) - - # Timeout - session3 = await store.get_session("user_id") - assert session3.user_id == "user_id" - assert session3.conversation_id is None - assert session3.updated_at > last_updated_at diff --git a/tests/test_session.py b/tests/test_session.py new file mode 100644 index 0000000..7b50d80 --- /dev/null +++ b/tests/test_session.py @@ -0,0 +1,78 @@ +import pytest +import asyncio +from datetime import datetime, timezone +import os +from uuid import uuid4 +from linedify.session import ConversationSession, ConversationSessionStore + + +def test_conversation_session(): + now = datetime.now(timezone.utc) + session = ConversationSession("user_id", "conversation_id", now) + + assert session.user_id == "user_id" + assert session.conversation_id == "conversation_id" + assert session.updated_at == now + + session_dict = session.to_dict() + session_dict["user_id"] == session.user_id + session_dict["conversation_id"] == session.conversation_id + session_dict["updated_at"] == now.isoformat() + + session2 = ConversationSession.from_dict(session_dict) + assert session2.user_id == session.user_id + assert session2.conversation_id == session.conversation_id + assert session2.updated_at == session.updated_at + + +@pytest.mark.asyncio +async def test_conversation_session_store(): + store = ConversationSessionStore("sqlite:///test_sessions.db", 3) + assert str(store.engine.url) == "sqlite:///test_sessions.db" + assert store.timeout == 3.0 + + # New session + session = await store.get_session("user_id") + assert session.user_id == "user_id" + assert session.conversation_id is None + assert isinstance(session.updated_at, datetime) + + last_updated_at = session.updated_at + conversation_id = str(uuid4()) + session.conversation_id = conversation_id + await store.set_session(session) + + # Successive session + session2 = await store.get_session("user_id") + assert session2.user_id == "user_id" + assert session2.conversation_id == conversation_id + assert session2.updated_at > last_updated_at + + last_updated_at = session2.updated_at + await store.set_session(session2) + + await asyncio.sleep(store.timeout + 1.0) + + # Timeout + session3 = await store.get_session("user_id") + assert session3.user_id == "user_id" + assert session3.conversation_id is None + assert session3.updated_at > last_updated_at + + conversation_id3 = str(uuid4()) + session3.conversation_id = conversation_id3 + await store.set_session(session3) + + # Successive with another conversation_id + session4 = await store.get_session("user_id") + assert session4.user_id == "user_id" + assert session4.conversation_id == conversation_id3 + + # List sessions + sessions = await store.get_user_conversations("user_id") + + assert sessions[-2].user_id == "user_id" + assert sessions[-2].conversation_id == conversation_id + assert sessions[-1].user_id == "user_id" + assert sessions[-1].conversation_id == conversation_id3 + assert sessions[-1].updated_at > sessions[-2].updated_at From d5ecbf534cf2407a524db50ca64f0b477cbb3eb2 Mon Sep 17 00:00:00 2001 From: uezo Date: Fri, 12 Jul 2024 23:21:29 +0900 Subject: [PATCH 2/4] Update README.md --- README.md | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/README.md b/README.md index 9f3ae5b..ddcec89 100644 --- a/README.md +++ b/README.md @@ -150,6 +150,25 @@ line_dify.make_error_response = make_error_response ``` +## 💾 Conversation session + +Conversation sessions are managed by a database. By default, SQLite is used, but you can specify the file path or database type using `session_db_url`. For the syntax, please refer to SQLAlchemy's documentation. + +Additionally, you can specify the session validity period with `session_timeout`. The default is 3600 seconds. If this period elapses since the last conversation, a new conversation thread will be created on Dify when the next conversation starts. + +```python +line_dify = LineDify( + line_channel_access_token=YOUR_CHANNEL_ACCESS_TOKEN, + line_channel_secret=YOUR_CHANNEL_SECRET, + dify_api_key=DIFY_API_KEY, + dify_base_url=DIFY_BASE_URL, + dify_user=DIFY_USER, + session_db_url="sqlite:///your_sessions.db", # SQLAlchemy database url + session_timeout=1800, # Timeout in seconds +) +``` + + ## 🐝 Debug Set `verbose=True` to see the request and response, both from/to LINE and from/to Dify. From 67fe7c57a5a9e29912f0ac5e895b904123542295 Mon Sep 17 00:00:00 2001 From: uezo Date: Fri, 12 Jul 2024 23:24:52 +0900 Subject: [PATCH 3/4] Remove unnecessary imports --- linedify/integration.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/linedify/integration.py b/linedify/integration.py index 9d736e3..be265e1 100644 --- a/linedify/integration.py +++ b/linedify/integration.py @@ -1,7 +1,5 @@ -from datetime import datetime, timezone import json from logging import getLogger, NullHandler -import os from traceback import format_exc from typing import List, Tuple import aiohttp From 45816f017bfadd526af2e91d57d43df606121906 Mon Sep 17 00:00:00 2001 From: uezo Date: Fri, 12 Jul 2024 23:35:18 +0900 Subject: [PATCH 4/4] =?UTF-8?q?Update=20for=20v0.2.2=F0=9F=90=88?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- linedify/session.py | 2 +- setup.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/linedify/session.py b/linedify/session.py index c7bbe22..01c43cf 100644 --- a/linedify/session.py +++ b/linedify/session.py @@ -81,7 +81,7 @@ async def set_session(self, session: ConversationSession) -> None: db_session.merge(db_session_model) db_session.commit() - async def get_user_conversations(self, user_id: str, count: int = 20) -> list[ConversationSession]: + async def get_user_conversations(self, user_id: str, count: int = 20) -> List[ConversationSession]: with self.Session() as session: db_sessions = session.query(ConversationSessionModel).filter_by(user_id=user_id).order_by(ConversationSessionModel.updated_at.desc()).limit(count) user_conversations = [ConversationSession( diff --git a/setup.py b/setup.py index 7fc11ce..64ede5a 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ setup( name="linedify", - version="0.2.0", + version="0.2.2", url="https://github.com/uezo/linedify", author="uezo", author_email="uezo@uezo.net", @@ -12,7 +12,7 @@ long_description=open("README.md").read(), long_description_content_type="text/markdown", packages=find_packages(exclude=["examples*", "tests*"]), - install_requires=["aiohttp==3.9.5", "line-bot-sdk==3.11.0", "fastapi==0.111.0", "uvicorn==0.30.1"], + install_requires=["aiohttp==3.9.5", "line-bot-sdk==3.11.0", "fastapi==0.111.0", "uvicorn==0.30.1", "SQLAlchemy==2.0.31"], license="Apache v2", classifiers=[ "Programming Language :: Python :: 3"