Skip to content

Commit

Permalink
Neo4j chat message history (langchain-ai#13008)
Browse files Browse the repository at this point in the history
  • Loading branch information
tomasonjo authored and pprados committed Nov 20, 2023
1 parent f5607c7 commit ba24593
Show file tree
Hide file tree
Showing 4 changed files with 220 additions and 0 deletions.
76 changes: 76 additions & 0 deletions docs/docs/integrations/memory/neo4j_chat_message_history.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "91c6a7ef",
"metadata": {},
"source": [
"# Neo4j\n",
"\n",
"[Neo4j](https://en.wikipedia.org/wiki/Neo4j) is an open-source graph database management system, renowned for its efficient management of highly connected data. Unlike traditional databases that store data in tables, Neo4j uses a graph structure with nodes, edges, and properties to represent and store data. This design allows for high-performance queries on complex data relationships.\n",
"\n",
"This notebook goes over how to use `Neo4j` to store chat message history."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d15e3302",
"metadata": {},
"outputs": [],
"source": [
"from langchain.memory import Neo4jChatMessageHistory\n",
"\n",
"history = Neo4jChatMessageHistory(\n",
" url=\"bolt://localhost:7687\",\n",
" username=\"neo4j\",\n",
" password=\"password\",\n",
" session_id=\"session_id_1\"\n",
")\n",
"\n",
"history.add_user_message(\"hi!\")\n",
"\n",
"history.add_ai_message(\"whats up?\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "64fc465e",
"metadata": {},
"outputs": [],
"source": [
"history.messages"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8af285f8",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.8"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from langchain.memory.chat_message_histories.in_memory import ChatMessageHistory
from langchain.memory.chat_message_histories.momento import MomentoChatMessageHistory
from langchain.memory.chat_message_histories.mongodb import MongoDBChatMessageHistory
from langchain.memory.chat_message_histories.neo4j import Neo4jChatMessageHistory
from langchain.memory.chat_message_histories.postgres import PostgresChatMessageHistory
from langchain.memory.chat_message_histories.redis import RedisChatMessageHistory
from langchain.memory.chat_message_histories.rocksetdb import RocksetChatMessageHistory
Expand Down Expand Up @@ -48,4 +49,5 @@
"XataChatMessageHistory",
"ZepChatMessageHistory",
"UpstashRedisChatMessageHistory",
"Neo4jChatMessageHistory",
]
112 changes: 112 additions & 0 deletions libs/langchain/langchain/memory/chat_message_histories/neo4j.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
from typing import List, Optional, Union

from langchain.schema import BaseChatMessageHistory
from langchain.schema.messages import BaseMessage, messages_from_dict
from langchain.utils import get_from_env


class Neo4jChatMessageHistory(BaseChatMessageHistory):
"""Chat message history stored in a Neo4j database."""

def __init__(
self,
session_id: Union[str, int],
url: Optional[str] = None,
username: Optional[str] = None,
password: Optional[str] = None,
database: str = "neo4j",
node_label: str = "Session",
window: int = 3,
):
try:
import neo4j
except ImportError:
raise ValueError(
"Could not import neo4j python package. "
"Please install it with `pip install neo4j`."
)

# Make sure session id is not null
if not session_id:
raise ValueError("Please ensure that the session_id parameter is provided")

url = get_from_env("url", "NEO4J_URI", url)
username = get_from_env("username", "NEO4J_USERNAME", username)
password = get_from_env("password", "NEO4J_PASSWORD", password)
database = get_from_env("database", "NEO4J_DATABASE", database)

self._driver = neo4j.GraphDatabase.driver(url, auth=(username, password))
self._database = database
self._session_id = session_id
self._node_label = node_label
self._window = window

# Verify connection
try:
self._driver.verify_connectivity()
except neo4j.exceptions.ServiceUnavailable:
raise ValueError(
"Could not connect to Neo4j database. "
"Please ensure that the url is correct"
)
except neo4j.exceptions.AuthError:
raise ValueError(
"Could not connect to Neo4j database. "
"Please ensure that the username and password are correct"
)
# Create session node
self._driver.execute_query(
f"MERGE (s:`{self._node_label}` {{id:$session_id}})",
{"session_id": self._session_id},
).summary

@property
def messages(self) -> List[BaseMessage]: # type: ignore
"""Retrieve the messages from Neo4j"""
query = (
f"MATCH (s:`{self._node_label}`)-[:LAST_MESSAGE]->(last_message) "
"WHERE s.id = $session_id MATCH p=(last_message)<-[:NEXT*0.."
f"{self._window*2}]-() WITH p, length(p) AS length "
"ORDER BY length DESC LIMIT 1 UNWIND reverse(nodes(p)) AS node "
"RETURN {data:{content: node.content}, type:node.type} AS result"
)
records, _, _ = self._driver.execute_query(
query, {"session_id": self._session_id}
)

messages = messages_from_dict([el["result"] for el in records])
return messages

def add_message(self, message: BaseMessage) -> None:
"""Append the message to the record in Neo4j"""
query = (
f"MATCH (s:`{self._node_label}`) WHERE s.id = $session_id "
"OPTIONAL MATCH (s)-[lm:LAST_MESSAGE]->(last_message) "
"CREATE (s)-[:LAST_MESSAGE]->(new:Message) "
"SET new += {type:$type, content:$content} "
"WITH new, lm, last_message WHERE last_message IS NOT NULL "
"CREATE (last_message)-[:NEXT]->(new) "
"DELETE lm"
)
self._driver.execute_query(
query,
{
"type": message.type,
"content": message.content,
"session_id": self._session_id,
},
).summary

def clear(self) -> None:
"""Clear session memory from Neo4j"""
query = (
f"MATCH (s:`{self._node_label}`)-[:LAST_MESSAGE]->(last_message) "
"WHERE s.id = $session_id MATCH p=(last_message)<-[:NEXT]-() "
"WITH p, length(p) AS length ORDER BY length DESC LIMIT 1 "
"UNWIND nodes(p) as node DETACH DELETE node;"
)
self._driver.execute_query(query, {"session_id": self._session_id}).summary

def __del__(self) -> None:
if self._driver:
self._driver.close()
30 changes: 30 additions & 0 deletions libs/langchain/tests/integration_tests/memory/test_neo4j.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import json

from langchain.memory import ConversationBufferMemory
from langchain.memory.chat_message_histories import Neo4jChatMessageHistory
from langchain.schema.messages import _message_to_dict


def test_memory_with_message_store() -> None:
"""Test the memory with a message store."""
# setup MongoDB as a message store
message_history = Neo4jChatMessageHistory(session_id="test-session")
memory = ConversationBufferMemory(
memory_key="baz", chat_memory=message_history, return_messages=True
)

# add some messages
memory.chat_memory.add_ai_message("This is me, the AI")
memory.chat_memory.add_user_message("This is me, the human")

# get the message history from the memory store and turn it into a json
messages = memory.chat_memory.messages
messages_json = json.dumps([_message_to_dict(msg) for msg in messages])

assert "This is me, the AI" in messages_json
assert "This is me, the human" in messages_json

# remove the record from Azure Cosmos DB, so the next test run won't pick it up
memory.chat_memory.clear()

assert memory.chat_memory.messages == []

0 comments on commit ba24593

Please sign in to comment.