In [1]:
import openai
import os
import uuid

from dotenv import find_dotenv, load_dotenv
from pydantic import BaseModel, Field
from trustcall import create_extractor

from langchain_core.messages import merge_message_runs, HumanMessage, SystemMessage
from langchain_core.runnables.config import RunnableConfig
from langchain_openai import ChatOpenAI

from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, MessagesState, StateGraph, START
from langgraph.store.base import BaseStore
from langgraph.store.memory import InMemoryStore

In [2]:
# Load variables such as OPENAI_API_KEY from a .env file located by find_dotenv().
_ = load_dotenv(find_dotenv())

# Chat model used by both the chat node and the Trustcall memory extractor.
OPENAI_MODEL = "gpt-3.5-turbo"

# Fail fast with an actionable message instead of a bare KeyError, and also reject
# an empty value, which would otherwise only fail later on the first API call.
_api_key = os.environ.get("OPENAI_API_KEY")
if not _api_key:
    raise RuntimeError("OPENAI_API_KEY is not set; export it or add it to a .env file.")
openai.api_key = _api_key

llm = ChatOpenAI(model=OPENAI_MODEL)

In [3]:
class Memory(BaseModel):
    # One long-term fact about the user; Trustcall emits instances of this as the
    # "Memory" tool. Documented with comments rather than a docstring because a
    # class docstring would presumably be surfaced as the tool's description.
    content: str = Field(
        ...,  # required: every memory must carry its text
        description="The main content of the memory. For example: User expressed interest in learning about French.",
    )

In [4]:
# Extractor that answers with `Memory` tool calls. Existing memories are supplied
# per call (see the "existing" input in write_memory); enable_inserts lets it add
# brand-new memories instead of only revising the ones it is given.
trustcall_extractor = create_extractor(
    llm,
    tools=[Memory],
    tool_choice=Memory.__name__,  # force the one tool we expose ("Memory")
    enable_inserts=True,
)

In [5]:
# System prompt for the chat node. `{memory}` is filled in by call_model with one
# "- <content>" line per stored memory for the current user (empty when none).
# NOTE(review): the body keeps its 4-space indentation, which is sent to the model
# verbatim; wrap in textwrap.dedent if that is unwanted.
MODEL_SYSTEM_MESSAGE = """
    You are a helpful chatbot. You are designed to be a companion to a user. 

    You have a long term memory which keeps track of information you learn about the user over time.

    Current Memory (may include updated memories from this conversation): 

    {memory}
"""

# Instruction prepended as a SystemMessage to the conversation before write_memory
# hands it to the Trustcall extractor. It has no format placeholders.
TRUSTCALL_INSTRUCTION = """
    Reflect on following interaction. 

    Use the provided tools to retain any necessary memories about the user. 

    Use parallel tool calling to handle updates and insertions simultaneously:
"""

def _search_all_memories(store: BaseStore, namespace: tuple[str, ...], page_size: int = 100) -> list:
    """Return every item stored under ``namespace``.

    ``BaseStore.search`` returns at most ``limit`` items per call (10 by default),
    so a single un-paginated call silently drops memories once a user has more
    than that. Page through with ``offset`` until a short page signals the end.
    """
    items = []
    while True:
        page = store.search(namespace, limit=page_size, offset=len(items))
        items.extend(page)
        if len(page) < page_size:
            return items


def call_model(state: MessagesState, config: RunnableConfig, store: BaseStore):
    """Answer the user, personalised with all long-term memories stored for them.

    Looks up ``config["configurable"]["user_id"]`` and loads every memory in the
    ``("memories", user_id)`` namespace. It renders each memory's ``content`` as a
    "- ..." line into ``MODEL_SYSTEM_MESSAGE``, then invokes the chat model on that
    system prompt followed by the conversation so far.

    Returns:
        ``{"messages": response}``, where ``response`` is the model's reply message.
    """
    user_id = config["configurable"]["user_id"]
    namespace = ("memories", user_id)
    memories = _search_all_memories(store, namespace)

    info = "\n".join(f"- {mem.value['content']}" for mem in memories)
    system_msg = MODEL_SYSTEM_MESSAGE.format(memory=info)

    response = llm.invoke([SystemMessage(content=system_msg)] + state["messages"])

    return {"messages": response}


def write_memory(state: MessagesState, config: RunnableConfig, store: BaseStore):
    """Reflect on the conversation and upsert memories for the current user.

    All existing memories in ``("memories", user_id)`` are passed to the Trustcall
    extractor as ``(key, "Memory", value)`` triples so it can revise them. Each
    resulting document is written back with ``store.put``:
    - if Trustcall reports a ``json_doc_id``, that key is reused, so the existing
      entry is overwritten;
    - otherwise the document is stored under a fresh UUID key.

    Returns:
        None. This node only writes to ``store`` and does not update graph state.
    """
    user_id = config["configurable"]["user_id"]
    namespace = ("memories", user_id)
    existing_items = _search_all_memories(store, namespace)

    # Format the existing memories for the Trustcall extractor: (doc_id, tool_name, value)
    # triples, or None when the user has no memories yet.
    tool_name = "Memory"
    existing_memory = (
        [(item.key, tool_name, item.value) for item in existing_items]
        if existing_items
        else None
    )

    # Prepend the extraction instruction to the chat history; merge_message_runs
    # collapses any consecutive messages of the same type.
    updated_messages = list(
        merge_message_runs(
            messages=[SystemMessage(content=TRUSTCALL_INSTRUCTION)] + state["messages"]
        )
    )

    result = trustcall_extractor.invoke(
        {"messages": updated_messages, "existing": existing_memory}
    )

    # Save the memories from Trustcall to the store. `or` (rather than a .get default)
    # also covers a present-but-empty/None json_doc_id, which must never become a key.
    for response, meta in zip(result["responses"], result["response_metadata"]):
        store.put(
            namespace,
            meta.get("json_doc_id") or str(uuid.uuid4()),
            response.model_dump(mode="json"),
        )

In [6]:
# Two-node pipeline: reply to the user first, then reflect on the turn and persist memories.
# The uncompiled builder gets its own name so `graph` always refers to the compiled graph
# and this cell stays safe to re-run.
builder = StateGraph(MessagesState)
builder.add_node("call_model", call_model)
builder.add_node("write_memory", write_memory)
builder.add_edge(START, "call_model")
builder.add_edge("call_model", "write_memory")
builder.add_edge("write_memory", END)

# Long-term store shared by every thread; memories live under ("memories", user_id).
across_thread_memory = InMemoryStore()
# Checkpointer holding each thread's conversation (selected via config "thread_id").
within_thread_memory = MemorySaver()

graph = builder.compile(checkpointer=within_thread_memory, store=across_thread_memory)


In [7]:
# Thread "1" for user "1": introduce the user so write_memory has a fact to store.
config = {"configurable": {"user_id": "1", "thread_id": "1"}}
input_messages = [HumanMessage(content="Hi, my name is Yoona")]

# Each streamed chunk carries the graph's `messages` list; print only the newest one.
for snapshot in graph.stream({"messages": input_messages}, config, stream_mode="values"):
    newest = snapshot["messages"][-1]
    newest.pretty_print()


Hi, my name is Yoona

Hello Yoona! It's nice to meet you. How can I assist you today?


In [8]:
# Same thread and user: share interests for write_memory to pick up.
input_messages = [HumanMessage(content="I love 2nd generation k-pop group and AI enthusiast.")]
events = graph.stream({"messages": input_messages}, config, stream_mode="values")
for snapshot in events:
    snapshot["messages"][-1].pretty_print()


I love 2nd generation k-pop group and AI enthusiast.

That's great to hear, Yoona! It's cool that you're a fan of 2nd generation K-pop groups and interested in AI. Is there a specific group or AI topic you'd like to discuss today?


In [9]:
# Inspect the long-term memories written for user "1".
user_id = "1"
namespace = ("memories", user_id)

# BaseStore.search returns at most `limit` items per call (10 by default), so page
# through the namespace to list every stored memory, not just the first page.
PAGE_SIZE = 100
memories = []
while True:
    page = across_thread_memory.search(namespace, limit=PAGE_SIZE, offset=len(memories))
    memories.extend(page)
    if len(page) < PAGE_SIZE:
        break

for m in memories:
    print(m.dict())

{'value': {'content': "User's name is Yoona"}, 'key': 'a4e0fb28-1968-48e2-bdf8-d0f19a94db54', 'namespace': ['memories', '1'], 'created_at': '2025-01-01T10:18:34.832644+00:00', 'updated_at': '2025-01-01T10:18:34.832645+00:00', 'score': None}
{'value': {'content': 'User loves 2nd generation K-pop group and is an AI enthusiast.'}, 'key': 'dcb5ea3a-8976-42fb-935a-62e73eee3cd2', 'namespace': ['memories', '1'], 'created_at': '2025-01-01T10:18:36.677389+00:00', 'updated_at': '2025-01-01T10:18:36.677394+00:00', 'score': None}


In [10]:
# New thread "2" for the same user "1": no shared chat history, so any
# personalisation in the reply has to come from the long-term memory store.
config = {"configurable": {"user_id": "1", "thread_id": "2"}}
input_messages = [HumanMessage(content="What book about AI do you recommend for me?")]

events = graph.stream({"messages": input_messages}, config, stream_mode="values")
for snapshot in events:
    newest = snapshot["messages"][-1]
    newest.pretty_print()


What book about AI do you recommend for me?

I recommend "Life 3.0: Being Human in the Age of Artificial Intelligence" by Max Tegmark. It's a thought-provoking book that explores the impact of AI on our future society and raises important questions about the development of artificial intelligence. It should be an interesting read for you as an AI enthusiast!
