Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -162,4 +162,4 @@ cython_debug/
examples/
pytest.ini
.DS_store
test_sessions/
*.db
19 changes: 19 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,25 @@ line_dify.make_error_response = make_error_response
```


## 💾 Conversation session

Conversation sessions are managed by a database. By default, SQLite is used, but you can specify the file path or database type using `session_db_url`. For the syntax, please refer to SQLAlchemy's documentation.

Additionally, you can specify the session validity period with `session_timeout`. The default is 3600 seconds. If this period elapses since the last conversation, a new conversation thread will be created on Dify when the next conversation starts.

```python
line_dify = LineDify(
line_channel_access_token=YOUR_CHANNEL_ACCESS_TOKEN,
line_channel_secret=YOUR_CHANNEL_SECRET,
dify_api_key=DIFY_API_KEY,
dify_base_url=DIFY_BASE_URL,
dify_user=DIFY_USER,
session_db_url="sqlite:///your_sessions.db", # SQLAlchemy database url
session_timeout=1800, # Timeout in seconds
)
```


## 🐝 Debug

Set `verbose=True` to see the request and response, both from/to LINE and from/to Dify.
Expand Down
75 changes: 7 additions & 68 deletions linedify/integration.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from datetime import datetime, timezone
import json
from logging import getLogger, NullHandler
import os
from traceback import format_exc
from typing import List, Tuple
import aiohttp
Expand All @@ -13,76 +11,13 @@
ImageMessage, SendMessage, TextSendMessage
)
from .dify import DifyAgent, DifyType
from .session import ConversationSessionStore


logger = getLogger(__name__)
logger.addHandler(NullHandler())


class ConversationSession:
def __init__(self, user_id: str, conversation_id: str = None, updated_at: datetime = None) -> None:
self.user_id = user_id
self.conversation_id = conversation_id
self.updated_at = updated_at or datetime.now(timezone.utc)

def to_dict(self):
return {
"user_id": self.user_id,
"conversation_id": self.conversation_id,
"updated_at": self.updated_at.isoformat()
}

@staticmethod
def from_dict(data):
return ConversationSession(
user_id=data["user_id"],
conversation_id=data.get("conversation_id"),
updated_at=datetime.fromisoformat(data["updated_at"])
)


class ConversationSessionStore:
def __init__(self, persisit_directory: str = "sessions", timeout: float = 3600.0) -> None:
self.timeout = timeout
self.directory = persisit_directory
if not os.path.exists(persisit_directory):
os.makedirs(persisit_directory)

def _get_file_path(self, user_id: str) -> str:
return os.path.join(self.directory, f"{user_id}.json")

def _load_session(self, user_id: str) -> ConversationSession:
file_path = self._get_file_path(user_id)
if not os.path.exists(file_path):
return None

with open(file_path, "r") as file:
data = json.load(file)
return ConversationSession.from_dict(data)

def _save_session(self, session: ConversationSession) -> None:
file_path = self._get_file_path(session.user_id)
with open(file_path, "w") as file:
json.dump(session.to_dict(), file)

async def get_session(self, user_id: str) -> ConversationSession:
if not user_id:
raise Exception("user_id is required")

session = self._load_session(user_id)
if session is None or (datetime.now(timezone.utc) - session.updated_at).total_seconds() > self.timeout:
session = ConversationSession(user_id)

return session

async def set_session(self, session: ConversationSession) -> None:
if not session.user_id:
raise Exception("user_id is required")

session.updated_at = datetime.now(timezone.utc)
self._save_session(session)


class LineDifyIntegrator:
def __init__(self, *,
line_channel_access_token: str,
Expand All @@ -92,7 +27,8 @@ def __init__(self, *,
dify_user: str,
dify_type: DifyType = DifyType.Agent,
error_response: str = None,
conversation_timeout: float = 3600.0,
session_db_url: str = "sqlite:///sessions.db",
session_timeout: float = 3600.0,
verbose: bool = False
) -> None:

Expand All @@ -113,7 +49,10 @@ def __init__(self, *,
"location": self.parse_location_message
}

self.conversation_session_store = ConversationSessionStore(timeout=conversation_timeout)
self.conversation_session_store = ConversationSessionStore(
db_url=session_db_url,
timeout=session_timeout
)

# Dify
self.dify_agent = DifyAgent(
Expand Down
93 changes: 93 additions & 0 deletions linedify/session.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
from datetime import datetime, timezone
from typing import List
from sqlalchemy import create_engine, Column, String, DateTime, UniqueConstraint
from sqlalchemy.orm import declarative_base, sessionmaker


class ConversationSession:
def __init__(self, user_id: str, conversation_id: str = None, updated_at: datetime = None) -> None:
self.user_id = user_id
self.conversation_id = conversation_id
self.updated_at = updated_at or datetime.now(timezone.utc)

def to_dict(self):
return {
"user_id": self.user_id,
"conversation_id": self.conversation_id,
"updated_at": self.updated_at.isoformat()
}

@staticmethod
def from_dict(data):
return ConversationSession(
user_id=data["user_id"],
conversation_id=data.get("conversation_id"),
updated_at=datetime.fromisoformat(data["updated_at"])
)


Base = declarative_base()


class ConversationSessionModel(Base):
__tablename__ = "conversation_sessions"

id = Column(String, primary_key=True)
user_id = Column(String, nullable=False)
conversation_id = Column(String)
updated_at = Column(DateTime(timezone=True), nullable=False)

__table_args__ = (UniqueConstraint("user_id", "conversation_id", name="uix_user_conversation"),)


class ConversationSessionStore:
def __init__(self, db_url: str = "sqlite:///sessions.db", timeout: float = 3600.0) -> None:
self.timeout = timeout
self.engine = create_engine(db_url)
Base.metadata.create_all(self.engine)
self.Session = sessionmaker(bind=self.engine)

async def get_session(self, user_id: str) -> ConversationSession:
if not user_id:
raise Exception("user_id is required")

with self.Session() as session:
db_session = session.query(ConversationSessionModel).filter_by(user_id=user_id).order_by(ConversationSessionModel.updated_at.desc()).first()

now = datetime.now(timezone.utc)

if db_session is None or (now - db_session.updated_at.replace(tzinfo=timezone.utc)).total_seconds() > self.timeout:
return ConversationSession(user_id)

return ConversationSession(
user_id=db_session.user_id,
conversation_id=db_session.conversation_id,
updated_at=db_session.updated_at.replace(tzinfo=timezone.utc)
)

async def set_session(self, session: ConversationSession) -> None:
if not session.user_id:
raise Exception("user_id is required")

session.updated_at = datetime.now(timezone.utc)

with self.Session() as db_session:
db_session_model = ConversationSessionModel(
id=f"{session.user_id}_{session.conversation_id}",
user_id=session.user_id,
conversation_id=session.conversation_id,
updated_at=session.updated_at
)
db_session.merge(db_session_model)
db_session.commit()

async def get_user_conversations(self, user_id: str, count: int = 20) -> List[ConversationSession]:
with self.Session() as session:
db_sessions = session.query(ConversationSessionModel).filter_by(user_id=user_id).order_by(ConversationSessionModel.updated_at.desc()).limit(count)
user_conversations = [ConversationSession(
user_id=db_session.user_id,
conversation_id=db_session.conversation_id,
updated_at=db_session.updated_at.replace(tzinfo=timezone.utc)
) for db_session in db_sessions]
user_conversations.reverse()
return user_conversations
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ aiohttp==3.9.5
line-bot-sdk==3.11.0
fastapi==0.111.0
uvicorn==0.30.1
SQLAlchemy==2.0.31
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

setup(
name="linedify",
version="0.2.0",
version="0.2.2",
url="https://github.com/uezo/linedify",
author="uezo",
author_email="uezo@uezo.net",
Expand All @@ -12,7 +12,7 @@
long_description=open("README.md").read(),
long_description_content_type="text/markdown",
packages=find_packages(exclude=["examples*", "tests*"]),
install_requires=["aiohttp==3.9.5", "line-bot-sdk==3.11.0", "fastapi==0.111.0", "uvicorn==0.30.1"],
install_requires=["aiohttp==3.9.5", "line-bot-sdk==3.11.0", "fastapi==0.111.0", "uvicorn==0.30.1", "SQLAlchemy==2.0.31"],
license="Apache v2",
classifiers=[
"Programming Language :: Python :: 3"
Expand Down
57 changes: 0 additions & 57 deletions tests/test_integration.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
import pytest
import asyncio
import os
from datetime import datetime, timezone
from linebot.models import TextMessage, ImageMessage, StickerMessage, LocationMessage
from linedify import LineDify, DifyType
from linedify.integration import ConversationSession, ConversationSessionStore


@pytest.fixture
Expand Down Expand Up @@ -73,57 +70,3 @@ async def test_parse_text_message(line_dify):
parsed_text, parsed_image = await line_dify.parse_location_message(text_message)
assert parsed_text == f"You received a location info from user in messenger app:\n - address: Jiyugaoka, Tokyo\n - latitude: 35.6\n - longitude: 139.6"
assert parsed_image is None


def test_conversation_session():
now = datetime.now(timezone.utc)
session = ConversationSession("user_id", "conversation_id", now)

assert session.user_id == "user_id"
assert session.conversation_id == "conversation_id"
assert session.updated_at == now

session_dict = session.to_dict()
session_dict["user_id"] == session.user_id
session_dict["conversation_id"] == session.conversation_id
session_dict["updated_at"] == now.isoformat()

session2 = ConversationSession.from_dict(session_dict)
assert session2.user_id == session.user_id
assert session2.conversation_id == session.conversation_id
assert session2.updated_at == session.updated_at


@pytest.mark.asyncio
async def test_conversation_session_store():
store = ConversationSessionStore("test_sessions", 3)
assert store.directory == "test_sessions"
assert store.timeout == 3.0
assert os.path.exists("test_sessions") is True

# New session
session = await store.get_session("user_id")
assert session.user_id == "user_id"
assert session.conversation_id is None
assert isinstance(session.updated_at, datetime)

last_updated_at = session.updated_at
session.conversation_id = "conversation_id"
await store.set_session(session)

# Successive session
session2 = await store.get_session("user_id")
assert session2.user_id == "user_id"
assert session2.conversation_id == "conversation_id"
assert session2.updated_at > last_updated_at

last_updated_at = session2.updated_at
await store.set_session(session2)

await asyncio.sleep(store.timeout + 1.0)

# Timeout
session3 = await store.get_session("user_id")
assert session3.user_id == "user_id"
assert session3.conversation_id is None
assert session3.updated_at > last_updated_at
Loading