In [None]:
%load_ext dotenv
%dotenv
%load_ext mypy_ipython
    

In [None]:
from langgraph.graph import START, END, StateGraph, add_messages, MessagesState

In [None]:
from typing_extensions import TypedDict

In [None]:
from langchain_openai.chat_models import ChatOpenAI

In [None]:
from langchain_core.messages import HumanMessage, BaseMessage, AIMessage, RemoveMessage, SystemMessage

In [None]:
from langchain_core.runnables import Runnable

In [None]:
from collections.abc import Sequence

In [None]:
from typing import Literal, Annotated

## Define the State


In [None]:
class State(MessagesState, total=False):
    """Graph state: the chat ``messages`` plus a running conversation summary.

    ``messages`` is inherited from ``MessagesState``. ``summary`` is declared
    optional (``total=False`` applies only to keys defined in this class), because
    it only exists after ``trim_messages`` has run. Every node constructs
    ``State(messages=...)`` without it and reads it via ``state.get("summary", "")``.
    """

    # Running summary of earlier turns, written by the trim_messages node.
    summary: str

In [None]:
# Scratch check: calling the TypedDict with no arguments builds an empty dict.
test_state = State()


In [None]:
# Display the empty state ({}).
test_state


In [None]:
# There is no "summary" key yet, so .get falls back to "" and this evaluates to
# False. The nodes below rely on the same fallback via state.get("summary", "").
bool(test_state.get("summary", ""))

## Define the Nodes

In [None]:
# Shared chat model used by both the chatbot and trim_messages nodes.
chat = ChatOpenAI(
    model="gpt-4",
    seed=365,                   # fixed seed + zero temperature to reduce run-to-run variation
    temperature=0,
    max_completion_tokens=100,  # caps every completion, including the summaries
)

In [None]:
def ask_question(state: State) -> State:
    """Prompt the user for a new question and read it from stdin.

    Args:
        state: Current graph state (not read by this node).

    Returns:
        Partial state update that appends the prompt as an ``AIMessage``,
        followed by the user's reply as a ``HumanMessage``.
    """
    print("\n ----> ENTERING Ask question node:")

    prompt_text = "What is your question?"
    print(prompt_text)

    user_reply = input()
    new_messages = [AIMessage(prompt_text), HumanMessage(user_reply)]
    return State(messages=new_messages)

In [None]:
def trim_messages(state: State) -> State:
    """Fold the current messages into the running summary, then clear them.

    Builds a plain-text transcript of ``state["messages"]`` and asks ``chat``
    to extend the previous summary (empty if none exists yet) with it. The
    prompt is printed for visibility.

    Args:
        state: Graph state with the conversation ``messages`` and an optional
            ``summary``.

    Returns:
        Partial state update with a ``RemoveMessage`` for every current message
        (clearing them from ``messages``) and the model's new ``summary``.
    """
    print("\n-------> ENTERING trim_messages:")

    transcript = "".join(
        f"{message.type}: {message.content}\n\n" for message in state["messages"]
    )
    previous_summary = state.get("summary", "")

    summary_instructions = f'''
Update the ongoing summary by incorporating the new lines of conversation below.  
Build upon the previous summary rather than repeating it so that the result  
reflects the most recent context and developments.


Previous Summary:
{previous_summary}

New Conversation:
{transcript}
'''

    print(summary_instructions)

    summary_reply = chat.invoke([HumanMessage(summary_instructions)])

    # One removal marker per message currently in state.
    removals = [RemoveMessage(id=message.id) for message in state["messages"]]

    return State(messages=removals, summary=summary_reply.content)

In [None]:
def chatbot(state: State) -> State:
    """Answer the user's latest question, using the running summary as context.

    Prints every message currently in state, sends them to ``chat``, and
    pretty-prints the reply. When a summary exists, the messages are prefixed
    with a ``SystemMessage`` carrying it.

    Args:
        state: Graph state with the current ``messages`` and, optionally, the
            ``summary`` produced by ``trim_messages``.

    Returns:
        Partial state update appending the model's reply to ``messages``.
    """
    print("\n ----> ENTERING Chatbot:")
    for message in state["messages"]:
        message.pretty_print()

    summary = state.get("summary", "")
    prompt: list[BaseMessage] = list(state["messages"])
    if summary:
        # Only announce a summary when there is one. On the first turn an empty
        # "summary so far" section would hand the model a misleading blank context.
        system_message = (
            "Here's a quick summary of what's been discussed so far:\n"
            f"{summary}\n"
            "Keep this in mind as you answer the next question"
        )
        prompt.insert(0, SystemMessage(system_message))

    response = chat.invoke(prompt)
    response.pretty_print()
    return State(messages=[response])
    

In [None]:
def ask_another_question(state: State) -> State:
    """Ask the user whether to continue, reading the yes/no answer from stdin.

    Args:
        state: Current graph state (not read by this node).

    Returns:
        Partial state update that appends the yes/no prompt as an ``AIMessage``
        and the user's raw reply as a ``HumanMessage``. ``routing_function``
        inspects that reply.
    """
    print("\n ----> ENTERING Ask another question node:")

    prompt_text = "Would you like another question (yes/no)?"
    print(prompt_text)

    user_reply = input()
    return State(messages=[AIMessage(prompt_text), HumanMessage(user_reply)])

## Define the Routing Function

In [None]:
def routing_function(state: State) -> Literal["trim_messages", "__end__"]:
    """Decide whether to keep the conversation going after ``ask_another_question``.

    Args:
        state: Graph state whose last message is the user's yes/no reply.

    Returns:
        ``"trim_messages"`` to summarise and ask another question, or
        ``"__end__"`` to stop. Both are keys of the ``path_map`` used when
        wiring this router.
    """
    reply = state["messages"][-1].content
    # Content may be a str or a list of content blocks; only text can be "yes".
    # Normalise case and whitespace so "Yes" or "yes " is not treated as "no".
    if isinstance(reply, str) and reply.strip().lower() == "yes":
        return "trim_messages"
    return "__end__"
        

## Define the Graph

In [None]:
# Build on the notebook's State (not plain MessagesState). Otherwise the graph has
# no "summary" channel, the summary returned by trim_messages is silently dropped,
# and chatbot never sees it.
graph = StateGraph(State)

In [None]:
# Register each node under the name the edges below refer to.
for node_name, node_fn in (
    ("ask_question", ask_question),
    ("chatbot", chatbot),
    ("ask_another_question", ask_another_question),
    ("trim_messages", trim_messages),
):
    graph.add_node(node_name, node_fn)

# Fixed edges: question -> answer -> "another?"; the summarise step loops back.
for edge_source, edge_target in (
    (START, "ask_question"),
    ("ask_question", "chatbot"),
    ("chatbot", "ask_another_question"),
    ("trim_messages", "ask_question"),
):
    graph.add_edge(edge_source, edge_target)

# After "another question?", routing_function picks between looping and stopping.
graph.add_conditional_edges(
    source="ask_another_question",
    path=routing_function,
    path_map={"trim_messages": "trim_messages", "__end__": END},
)


In [None]:
# Compile the builder into an executable graph.
graph_compiled = graph.compile()

In [None]:
# Sanity check: the compiled graph is a LangChain Runnable (so it exposes .invoke).
isinstance(graph_compiled, Runnable)

In [None]:
# Show the compiled graph as the cell output; in Jupyter this is expected to render as a diagram.
graph_compiled

## Test the Graph

In [None]:
# Run the conversation loop from an empty history. The nodes read from stdin via
# input(), so this cell waits for the user. Replying "yes" to "Would you like
# another question (yes/no)?" loops back; other replies end the run.
graph_compiled.invoke(State(messages=[]))