From ba24593da38836d52d5bd9c629d912dec825b272 Mon Sep 17 00:00:00 2001 From: Tomaz Bratanic Date: Fri, 10 Nov 2023 20:53:34 +0100 Subject: [PATCH] Neo4j chat message history (#13008) --- .../memory/neo4j_chat_message_history.ipynb | 76 ++++++++++++ .../memory/chat_message_histories/__init__.py | 2 + .../memory/chat_message_histories/neo4j.py | 112 ++++++++++++++++++ .../integration_tests/memory/test_neo4j.py | 30 +++++ 4 files changed, 220 insertions(+) create mode 100644 docs/docs/integrations/memory/neo4j_chat_message_history.ipynb create mode 100644 libs/langchain/langchain/memory/chat_message_histories/neo4j.py create mode 100644 libs/langchain/tests/integration_tests/memory/test_neo4j.py diff --git a/docs/docs/integrations/memory/neo4j_chat_message_history.ipynb b/docs/docs/integrations/memory/neo4j_chat_message_history.ipynb new file mode 100644 index 000000000000000..238beb099825eeb --- /dev/null +++ b/docs/docs/integrations/memory/neo4j_chat_message_history.ipynb @@ -0,0 +1,76 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "91c6a7ef", + "metadata": {}, + "source": [ + "# Neo4j\n", + "\n", + "[Neo4j](https://en.wikipedia.org/wiki/Neo4j) is an open-source graph database management system, renowned for its efficient management of highly connected data. Unlike traditional databases that store data in tables, Neo4j uses a graph structure with nodes, edges, and properties to represent and store data. This design allows for high-performance queries on complex data relationships.\n", + "\n", + "This notebook goes over how to use `Neo4j` to store chat message history." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d15e3302", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.memory import Neo4jChatMessageHistory\n", + "\n", + "history = Neo4jChatMessageHistory(\n", + " url=\"bolt://localhost:7687\",\n", + " username=\"neo4j\",\n", + " password=\"password\",\n", + " session_id=\"session_id_1\"\n", + ")\n", + "\n", + "history.add_user_message(\"hi!\")\n", + "\n", + "history.add_ai_message(\"whats up?\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "64fc465e", + "metadata": {}, + "outputs": [], + "source": [ + "history.messages" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8af285f8", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.8" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/libs/langchain/langchain/memory/chat_message_histories/__init__.py b/libs/langchain/langchain/memory/chat_message_histories/__init__.py index c0b7c544a9b0b7a..a1497e8a1222578 100644 --- a/libs/langchain/langchain/memory/chat_message_histories/__init__.py +++ b/libs/langchain/langchain/memory/chat_message_histories/__init__.py @@ -13,6 +13,7 @@ from langchain.memory.chat_message_histories.in_memory import ChatMessageHistory from langchain.memory.chat_message_histories.momento import MomentoChatMessageHistory from langchain.memory.chat_message_histories.mongodb import MongoDBChatMessageHistory +from langchain.memory.chat_message_histories.neo4j import Neo4jChatMessageHistory from langchain.memory.chat_message_histories.postgres import PostgresChatMessageHistory from 
class Neo4jChatMessageHistory(BaseChatMessageHistory):
    """Chat message history stored in a Neo4j database.

    Graph layout used by the queries in this class:

    * one node labelled ``node_label`` per session, keyed by its ``id`` property;
    * one ``Message`` node per message, carrying ``type`` and ``content``;
    * ``(older)-[:NEXT]->(newer)`` relationships chaining the messages;
    * a single ``(session)-[:LAST_MESSAGE]->(newest)`` pointer.

    ``node_label`` is interpolated into the Cypher text (labels cannot be
    passed as query parameters), so it should come from trusted configuration.
    """

    def __init__(
        self,
        session_id: Union[str, int],
        url: Optional[str] = None,
        username: Optional[str] = None,
        password: Optional[str] = None,
        database: str = "neo4j",
        node_label: str = "Session",
        window: int = 3,
    ):
        """Connect to Neo4j and make sure the session node exists.

        Args:
            session_id: Identifier of the conversation; stored as the ``id``
                property of the session node. ``None`` and ``""`` are rejected.
            url: Connection URI, resolved through ``get_from_env`` together
                with the ``NEO4J_URI`` environment variable.
            username: Resolved together with ``NEO4J_USERNAME``.
            password: Resolved together with ``NEO4J_PASSWORD``.
            database: Resolved together with ``NEO4J_DATABASE``.
            node_label: Label of the session node.
            window: ``messages`` walks back at most ``window * 2`` ``NEXT``
                hops from the newest message.

        Raises:
            ValueError: If the ``neo4j`` package is missing, ``session_id`` is
                empty, or the server cannot be reached / rejects the
                credentials.
        """
        try:
            import neo4j
        except ImportError as err:
            # ValueError (not ImportError) is kept so existing callers that
            # catch it keep working; the original error is chained as cause.
            raise ValueError(
                "Could not import neo4j python package. "
                "Please install it with `pip install neo4j`."
            ) from err

        # Make sure session id is not null. A plain truthiness test would also
        # reject the perfectly valid integer id 0.
        if session_id is None or session_id == "":
            raise ValueError("Please ensure that the session_id parameter is provided")

        url = get_from_env("url", "NEO4J_URI", url)
        username = get_from_env("username", "NEO4J_USERNAME", username)
        password = get_from_env("password", "NEO4J_PASSWORD", password)
        database = get_from_env("database", "NEO4J_DATABASE", database)

        self._driver = neo4j.GraphDatabase.driver(url, auth=(username, password))
        # NOTE(review): _database is stored but never handed to execute_query,
        # so queries run against the server's default database - confirm
        # whether `database_=self._database` should be passed.
        self._database = database
        self._session_id = session_id
        self._node_label = node_label
        self._window = window

        # Verify connection
        try:
            self._driver.verify_connectivity()
        except neo4j.exceptions.ServiceUnavailable as err:
            raise ValueError(
                "Could not connect to Neo4j database. "
                "Please ensure that the url is correct"
            ) from err
        except neo4j.exceptions.AuthError as err:
            raise ValueError(
                "Could not connect to Neo4j database. "
                "Please ensure that the username and password are correct"
            ) from err

        # Create session node (MERGE leaves an already existing one untouched)
        self._driver.execute_query(
            f"MERGE (s:`{self._node_label}` {{id:$session_id}})",
            {"session_id": self._session_id},
        )

    @property
    def messages(self) -> List[BaseMessage]:  # type: ignore
        """Retrieve the messages from Neo4j.

        Starting at the session's ``LAST_MESSAGE`` node, the query follows up
        to ``window * 2`` ``NEXT`` relationships backwards, keeps the longest
        such path and reverses it, so the result is ordered oldest to newest.
        A session without messages yields an empty list.
        """
        query = (
            f"MATCH (s:`{self._node_label}`)-[:LAST_MESSAGE]->(last_message) "
            "WHERE s.id = $session_id MATCH p=(last_message)<-[:NEXT*0.."
            f"{self._window*2}]-() WITH p, length(p) AS length "
            "ORDER BY length DESC LIMIT 1 UNWIND reverse(nodes(p)) AS node "
            "RETURN {data:{content: node.content}, type:node.type} AS result"
        )
        records, _, _ = self._driver.execute_query(
            query, {"session_id": self._session_id}
        )
        return messages_from_dict([record["result"] for record in records])

    def add_message(self, message: BaseMessage) -> None:
        """Append the message to the record in Neo4j.

        A new ``Message`` node becomes the session's ``LAST_MESSAGE``. When a
        previous last message exists it is linked to the new node with ``NEXT``
        and the old ``LAST_MESSAGE`` relationship is deleted; for the first
        message the ``WHERE last_message IS NOT NULL`` filter skips that part.
        """
        query = (
            f"MATCH (s:`{self._node_label}`) WHERE s.id = $session_id "
            "OPTIONAL MATCH (s)-[lm:LAST_MESSAGE]->(last_message) "
            "CREATE (s)-[:LAST_MESSAGE]->(new:Message) "
            "SET new += {type:$type, content:$content} "
            "WITH new, lm, last_message WHERE last_message IS NOT NULL "
            "CREATE (last_message)-[:NEXT]->(new) "
            "DELETE lm"
        )
        self._driver.execute_query(
            query,
            {
                "type": message.type,
                "content": message.content,
                "session_id": self._session_id,
            },
        )

    def clear(self) -> None:
        """Clear session memory from Neo4j.

        Deletes every message node of the session; the session node itself is
        kept so that messages can be added again afterwards.
        """
        # `*0..` is essential: a bare `[:NEXT]` matches exactly one hop, which
        # finds nothing for a single-message session and removes only the two
        # newest nodes of a longer chain, orphaning the rest.
        query = (
            f"MATCH (s:`{self._node_label}`)-[:LAST_MESSAGE]->(last_message) "
            "WHERE s.id = $session_id MATCH p=(last_message)<-[:NEXT*0..]-() "
            "WITH p, length(p) AS length ORDER BY length DESC LIMIT 1 "
            "UNWIND nodes(p) as node DETACH DELETE node;"
        )
        self._driver.execute_query(query, {"session_id": self._session_id})

    def __del__(self) -> None:
        """Close the driver when the history object is garbage collected."""
        # __init__ can raise before _driver is assigned (missing package, bad
        # session id); getattr keeps finalization from raising AttributeError.
        driver = getattr(self, "_driver", None)
        if driver:
            driver.close()
def test_memory_with_message_store() -> None:
    """Test the memory with a message store."""
    # setup Neo4j as a message store; no connection arguments are passed, so
    # the NEO4J_* environment variables have to provide them
    message_history = Neo4jChatMessageHistory(session_id="test-session")
    memory = ConversationBufferMemory(
        memory_key="baz", chat_memory=message_history, return_messages=True
    )

    try:
        # add some messages
        memory.chat_memory.add_ai_message("This is me, the AI")
        memory.chat_memory.add_user_message("This is me, the human")

        # get the message history from the memory store and turn it into a json
        messages = memory.chat_memory.messages
        messages_json = json.dumps([_message_to_dict(msg) for msg in messages])

        assert "This is me, the AI" in messages_json
        assert "This is me, the human" in messages_json
    finally:
        # remove the messages from Neo4j even when an assertion above failed,
        # so the next test run won't pick them up
        memory.chat_memory.clear()

    assert memory.chat_memory.messages == []