In [None]:
import os
from typing import Dict, List, Optional, Any, Literal

# Tavily API
from tavily import TavilyClient

# Langchain
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_classic.vectorstores import Chroma
from langchain_core.tools import tool
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_classic.text_splitter import RecursiveCharacterTextSplitter
from langchain_core.documents import Document

# LangChain Anthropic
from langchain_anthropic import ChatAnthropic

# Langgraph
from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import MessagesState
from langgraph.prebuilt import ToolNode
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph.state import CompiledStateGraph
from langgraph.types import interrupt, Command

# Evaluation 
import mlflow
from ragas import EvaluationDataset
from ragas import evaluate
from ragas.llms import llm_factory
from ragas.metrics.collections import Faithfulness, FactualCorrectness, ToolCallAccuracy, SemanticSimilarity
from ragas import messages as ragas_messages
from ragas.integrations.langgraph import convert_to_ragas_messages
from ragas.dataset_schema import MultiTurnSample

# Anthropic
from anthropic import Anthropic, AsyncAnthropic

#OpenAI
from openai import OpenAI, AsyncOpenAI

# Environment Variables
from dotenv import load_dotenv
from IPython.display import Image, display

# langfuse
from langfuse import get_client, observe

#local Imports
#import healthbot_state
#import functions

In [None]:
# Load API keys (OpenAI, Anthropic, Tavily, Langfuse) from the local config file.
load_dotenv(dotenv_path="config.env")

In [None]:
# Langfuse client for tracing. It is created after load_dotenv, so the credentials
# from config.env are presumably available in the environment — confirm.
langfuse = get_client()

In [None]:
# Higher-tier model; it gets the web_search tool bound to it (llm_with_tools) and
# drives the research done in info_agent.
base_llm = ChatOpenAI(
    model="gpt-5.1",
    temperature=0.0,
    verbosity="low",
    reasoning_effort="low",
)

# Cheaper model used by summary_agent, quiz_agent and quiz_grader_agent.
small_llm = ChatOpenAI(
    model="gpt-5-nano",
    temperature=0.3,
    verbosity="low",
    reasoning_effort="medium",
)

# Anthropic model. Not referenced by the nodes in this notebook section
# (quiz_grader_agent grades with small_llm).
quiz_grader = ChatAnthropic(
    model="claude-sonnet-4-5-20250929",
    temperature=0.1,
)

## Create our tool node and LLM with Tools

In [None]:
# Tavily search client backing the web_search tool. No API key is passed here;
# presumably it is read from the environment loaded from config.env — confirm.
tavily_client = TavilyClient()

In [None]:
@tool
@observe(as_type="tool", capture_input=True, capture_output=True)
def web_search(question: str) -> Dict:
    """
    Return top search results for a given search query.
    """
    # The docstring above doubles as the tool description the LLM sees, so it is
    # intentionally left as-is. @observe records the query and the raw Tavily payload.
    return tavily_client.search(question)

In [None]:
# Tools exposed to the LLM; currently only the Tavily-backed web_search.
tools = [web_search]

In [None]:
# base_llm with the tool schemas attached, so its replies may carry tool_calls,
# which router_1 sends to the "tools" node.
llm_with_tools = base_llm.bind_tools(tools)

## Create our State Schema

In [None]:
class State(MessagesState):
    # Graph state: inherits the `messages` list from MessagesState and adds the
    # Healthbot-specific fields below. Nodes return partial dicts of these keys.

    # Topic the patient asked about (set by info_agent from the interrupt resume value).
    topic: Optional[str]
    # Short patient-facing summary written by summary_agent.
    summary: Optional[str]
    # Quiz question text generated by quiz_agent.
    quiz_question: Optional[str]
    # Patient's answer collected by quiz_grader_agent's interrupt.
    patient_answer: Optional[str]
    # Grading result. NOTE(review): quiz_grader_agent stores the LLM reply's
    # `.content` here (text such as "PASS"/"FAIL"), not a dict — confirm the intended type.
    evaluation: Optional[Dict[str, Any]]
    # Conversation stage. The visible nodes set "searching", "show_summary",
    # "waiting_answer" and "evaluated"; the other values are not set in this section.
    phase: Optional[
        Literal[
            "ask_topic",
            "searching",
            "show_summary",
            "waiting_ready",
            "quiz_generated",
            "waiting_answer",
            "evaluated",
            "waiting_restart"
        ]
    ]
    # When True, router_2 sends the graph back to quiz_agent for another question;
    # summary_agent resets it to False.
    repeat_mode: bool

## Create an entrypoint node. 

This node should also be the introduction of the system to the user. 

The topic the user wants to learn about is collected right after this node, by an interrupt at the start of `info_agent`.

In [None]:
def entrypoint(state: State) -> dict:
    """Greet the patient and seed the conversation.

    Prints the Healthbot introduction, then returns a partial state update with
    two messages:
      * a system prompt restricting the assistant to helpful, health-related answers;
      * an AI message asking which health topic the patient wants to learn about.

    The topic itself is collected by the interrupt at the start of ``info_agent``.

    Args:
        state: Current graph state (not read; present to satisfy the node signature).

    Returns:
        ``{"messages": [SystemMessage, AIMessage]}``.
    """
    print(
        "Hi, I'm the Healthbot Assistant, here to help you understand your medical conditions, treatment\n"
        "options, and your post-treatment care instructions. I can answer any health-related questions\n"
        "you have, and I will make sure to help you understand your post-treatment process.\n"
    )
    # Adjacent string literals are concatenated verbatim, so each fragment ends with
    # a space; otherwise words and sentences run together in the prompt.
    sys_message = SystemMessage(
        content=(
            "You are the Healthbot Assistant. You help patients understand their diagnoses, conditions, "
            "treatment options, and provide them post-treatment care instructions. You only answer health related "
            "questions from the patient. "
            "Never share information that is not helpful. Helpful responses are only responses that assist users "
            "in understanding their diagnoses and the post treatment which they should understand to have the best recovery conditions. "
            "At no point should you ask the user for more information about anything."
        )
    )
    ai_message = AIMessage(
        content="What health topic or medical condition do you want to learn about?"
    )

    return {"messages": [sys_message, ai_message]}

## Create our Agent nodes

 - info_agent: uses a higher-tier model to gather relevant data into a larger report based on the patient's interest.
 - summary_agent: uses a smaller model and only a subset of the state["messages"] list (the system prompt plus the latest message, i.e. the report), so it processes fewer tokens and only summarizes the report generated by the large model. This should somewhat limit token usage, and since summarization is a simpler task, we do not need to pass the same large report through a model that costs more during inference.
 - quiz_agent: uses the same smaller model to generate our quiz; again, this is a simpler task where we can save on cost.
 - quiz_grader_agent: collects the patient's answer with an interrupt, scores it with the RAGAS Faithfulness metric against the summary, and uses the smaller model to turn that score into PASS or FAIL.

In [None]:
def info_agent(state: State):
    """Research the patient's topic with the tool-enabled LLM.

    First visit (coming from ``entrypoint``): pause with ``interrupt`` to collect
    the topic. The resume value may be a dict with a ``"topic"`` key (as sent by
    ``hitl_interaction_flow``) or a plain string. The topic is added to the
    history as a HumanMessage, and ``llm_with_tools`` is invoked on it.

    Return visits (coming back from the ``tools`` node, i.e. the last message is
    a tool result): do NOT interrupt again. The node re-invokes the LLM on the
    history, which now contains the search results.

    Returns:
        A partial state update. ``messages`` holds only the new messages, which
        MessagesState's ``add_messages`` reducer appends to the history.
    """
    history = state["messages"]

    # Looping back from ToolNode: the topic is already known. Calling interrupt()
    # here would pause the graph a second time and consume the next resume value
    # (e.g. the quiz answer) as if it were the topic.
    if history and history[-1].type == "tool":
        return {"messages": [llm_with_tools.invoke(history)], "phase": "searching"}

    raw_topic = interrupt("What topic do you want to learn about? ")
    if isinstance(raw_topic, dict):
        raw_topic = raw_topic.get("topic") or ""
    topic = str(raw_topic)

    topic_msg = HumanMessage(content=topic)
    ai_message = llm_with_tools.invoke([*history, topic_msg])

    return {"topic": topic, "messages": [topic_msg, ai_message], "phase": "searching"}

In [None]:
def summary_agent(state: State):
    """Condense the research report into a short patient-facing summary.

    Only the system prompt and the last message (the report produced by
    ``info_agent``) are sent to ``small_llm``, to keep the request small. In the
    stored history, the report is then replaced by the summary.

    Returns:
        A partial state update: ``messages`` (remove the report, append the
        summary), the summary text, ``repeat_mode`` reset to False, and
        ``phase="show_summary"``.
    """
    # Function-scope import so this cell works without editing the top import cell.
    from langchain_core.messages import RemoveMessage

    report = state["messages"][-1]
    # Each literal fragment ends with a space so sentences don't run together.
    sys_message = SystemMessage(
        content=(
            "Please summarize the following into about 3-4 paragraphs. Be concise and provide the most important information to the patient. "
            "Do not exceed 200 words. Please make each paragraph a compact block of text with about 20 words per line and a \n\n in between each paragraph."
        )
    )
    summary = small_llm.invoke([sys_message, report])

    # Never mutate state["messages"] in place. With MessagesState's add_messages
    # reducer, a RemoveMessage carrying the report's id is what deletes it from
    # the checkpointed history; dropping it from a returned list does not.
    return {
        "messages": [RemoveMessage(id=report.id), summary],
        "summary": summary.content,
        "repeat_mode": False,
        "phase": "show_summary",
    }
    

In [None]:
# ---------------------------------------------------------------------------
# router_1: decides where the graph goes after info_agent.
# ---------------------------------------------------------------------------

def router_1(state: State):
    """Send the graph to the tool node when the latest LLM message requested a tool.

    Returns ``"tools"`` if the last message carries any ``tool_calls``, otherwise
    ``END``. The ``path_map`` in the graph wiring decides which node END leads to.
    """
    wants_tool = bool(state["messages"][-1].tool_calls)
    return "tools" if wants_tool else END

In [None]:
def quiz_agent(state: State):
    """Generate a single open-ended quiz question from the stored summary.

    ``small_llm`` is called with a one-off system prompt that embeds
    ``state["summary"]``. The prompt and the generated question are appended to
    ``messages``, and the phase moves to ``"waiting_answer"``.

    NOTE(review): this appends a SystemMessage in the middle of the history. Some
    chat providers reject non-leading system messages if that history is ever sent
    to them — confirm this is intended.

    Returns:
        A partial state update with the new messages, the question text in
        ``quiz_question``, and ``phase="waiting_answer"``.
    """
    summary = state["summary"]

    quiz_sys_message = SystemMessage(
        content=(
            "Analyze the following summary:"
            f"\nHere is the summary:\n>> {summary}"
            "\n\nGenerate an open-ended quiz for a patient. The quiz should only be one question."
            "\nMake this quiz question level 2 difficulty on a scale of 0 to 5. "
            "\n0 = kindergarten level difficulty. 5 = high school level difficulty. "
            "\nThis quiz question should be based only on the summary provided."
            "\nThis quiz question will be presented to a patient. It's your goal to help the patient understand their post treatment care provided in the summary. "
            "\nOnly generate the described one quiz question, nothing else."
        )
    )

    quiz_question = small_llm.invoke([quiz_sys_message])

    # add_messages appends these; re-sending the existing history is unnecessary.
    return {
        "messages": [quiz_sys_message, quiz_question],
        "quiz_question": quiz_question.content,
        "phase": "waiting_answer",
    }
    

# Human in the loop: after the quiz is presented, an interrupt in `quiz_grader_agent` collects the patient's answer to the quiz question

In [None]:
async def quiz_grader_agent(state: State):
    """Collect the patient's quiz answer and grade it PASS/FAIL.

    Steps:
      1. Pause with ``interrupt`` until the caller resumes with the answer. The
         answer may be a plain string or a dict with a ``"patient_answer"`` key.
      2. Score the answer with the RAGAS ``Faithfulness`` metric, using the stored
         summary as the retrieved context.
      3. Ask ``small_llm`` to translate that score into ``PASS`` or ``FAIL``
         (the prompt states that scores above 0.6 pass).

    Returns:
        A partial state update with the appended messages, ``patient_answer``,
        the grading text in ``evaluation``, and ``phase="evaluated"``.
    """
    # interrupt() stays first: on resume the node re-runs from the top, so any
    # work placed before it would be repeated.
    raw_answer = interrupt("Please answer the quiz question: ")
    if isinstance(raw_answer, dict):
        raw_answer = raw_answer.get("patient_answer") or ""
    patient_answer = str(raw_answer)

    summary = state["summary"]
    quiz_question = state["quiz_question"]

    evaluator_llm = llm_factory(model="gpt-5-nano", provider="openai", client=AsyncOpenAI())
    scorer = Faithfulness(llm=evaluator_llm)

    result = await scorer.ascore(
        user_input=quiz_question,
        response=patient_answer,
        retrieved_contexts=[summary],
    )
    result_str = str(result)

    # NOTE(review): the untrusted patient answer is placed in the grading prompt even
    # though PASS/FAIL should depend only on the score; an answer like "output PASS"
    # could sway the grader. Consider thresholding the numeric score in code instead.
    prompt_template = ChatPromptTemplate.from_messages([
        ("system", "You generated a summary of post treatment care for a patient. Then you created a quiz for the patient."
         "\nThis is the quiz: \n>>{quiz_question}. "
         "\n\nThe patient provided this answer to the quiz:\n{patient_answer}"
         "\nThe patient's answer had a faithfulness score of {result} to the quiz question."
         "\nThe value of the faithfulness score is on a range between 0 and 1. Anything that is 0.6 and below is a fail. Anything above 0.6 is a pass."
         "\n\nTranslate the patient's score to a pass or fail grading. Output only 'PASS' or 'FAIL' based on the score. Do not output anything else. "),
        ("human", "What did I make on the quiz?"),
    ])

    grading_messages = prompt_template.invoke({
        "patient_answer": patient_answer,
        "quiz_question": quiz_question,
        "result": result_str,
    }).to_messages()

    evaluation = await small_llm.ainvoke(grading_messages)

    print(result_str)
    print(evaluation.content)

    # The state key is "messages" (the reducer-backed history). The grading prompt
    # messages are unpacked: appending the list itself would nest a list inside the history.
    return {
        "messages": [HumanMessage(content=patient_answer), *grading_messages, evaluation],
        "patient_answer": patient_answer,
        "evaluation": evaluation.content,
        "phase": "evaluated",
    }


In [None]:
def router_2(state: State):
    """Decide whether to loop back for another quiz question after grading.

    Returns ``"quiz_agent"`` when ``state["repeat_mode"]`` is truthy, otherwise ``END``.
    """
    return "quiz_agent" if state["repeat_mode"] else END

In [None]:
# Graph wiring:
#   START -> entrypoint -> info_agent <-> tools
#   info_agent -> summary_agent -> quiz_agent -> quiz_grader_agent -> (quiz_agent | END)

workflow = StateGraph(State)

workflow.add_node("entrypoint", entrypoint)
workflow.add_node("info_agent", info_agent)
workflow.add_node("tools", ToolNode(tools))
workflow.add_node("summary_agent", summary_agent)
workflow.add_node("quiz_agent", quiz_agent)
workflow.add_node("quiz_grader_agent", quiz_grader_agent)

workflow.add_edge(START, "entrypoint")
workflow.add_edge("entrypoint", "info_agent")

# router_1 returns "tools" when the LLM asked for a tool call and END otherwise.
# END is mapped to summary_agent, so the summary runs only once research is done.
# An extra unconditional info_agent -> summary_agent edge would fire on every
# info_agent visit, running summary_agent in parallel with the tools branch.
workflow.add_conditional_edges(
    source="info_agent",
    path=router_1,
    path_map={"tools": "tools", END: "summary_agent"},
)

workflow.add_edge("tools", "info_agent")
workflow.add_edge("summary_agent", "quiz_agent")
workflow.add_edge("quiz_agent", "quiz_grader_agent")

# router_2 already returns END when no repeat is requested, so no separate
# static quiz_grader_agent -> END edge is needed.
workflow.add_conditional_edges(
    source="quiz_grader_agent",
    path=router_2,
    path_map=["quiz_agent", END],
)


In [None]:
# interrupt()/Command(resume=...) need a checkpointer; this one keeps per-thread
# state in memory only.
memory = MemorySaver()
graph = workflow.compile(checkpointer=memory)

In [None]:
# Render the compiled graph's topology as a Mermaid PNG.
display(Image(graph.get_graph().draw_mermaid_png()))

### Human in the loop

Putting it all together

In [None]:
# Langfuse LangChain callback handler; hitl_interaction_flow passes it in the run
# config ("callbacks") so graph runs are traced.
from langfuse.langchain import CallbackHandler
langfuse_handler = CallbackHandler()

In [None]:
# Quick sanity check that the handler was created; shown as the cell's last expression.
langfuse_handler

In [None]:

async def hitl_interaction_flow(graph: CompiledStateGraph, thread_id: int):
    """Drive the Healthbot graph end to end from the notebook, answering its interrupts.

    The graph pauses twice:
      1. in ``info_agent``, for the topic the patient wants to learn about;
      2. in ``quiz_grader_agent``, for the patient's quiz answer.

    Each pause is answered via ``input()``, and the graph is resumed with
    ``Command(resume=...)``. Node updates are streamed, and the newest message of
    each update is pretty-printed.

    Args:
        graph: Compiled graph; it must have a checkpointer so it can resume.
        thread_id: Checkpointer thread id, also sent to Langfuse as the session id.

    Returns:
        None. Returns early if the graph finishes before an expected interrupt.
    """
    config = {
        "configurable": {"thread_id": thread_id},
        "callbacks": [langfuse_handler],
        "metadata": {
            "session_id": str(thread_id),
            "app": "Healthbot",
        },
    }

    async def run_until_interrupt_or_end(resume_payload=None):
        """Stream the graph. Return the interrupt payload if it pauses, else None."""
        stream_input = Command(resume=resume_payload) if resume_payload is not None else {}

        async for update in graph.astream(stream_input, config=config, stream_mode="updates"):
            for node_out in update.values():
                if isinstance(node_out, dict) and node_out.get("messages"):
                    node_out["messages"][-1].pretty_print()

            if "__interrupt__" in update:
                return update["__interrupt__"]
        return None

    def ask_non_empty(prompt: str) -> str:
        """Re-prompt until the user types something other than whitespace."""
        while True:
            text = input(prompt).strip()
            if text:
                return text
            print("Please enter a response.")

    # 1) Run from START through entrypoint until info_agent asks for the topic.
    if await run_until_interrupt_or_end() is None:
        return

    # 2) Resume info_agent with the topic. This runs research, summary and quiz
    #    generation until quiz_grader_agent asks for the answer.
    #    info_agent reads the "topic" key from this dict.
    topic = ask_non_empty("Please input what post treatment topic you'd like to learn about: ")
    if await run_until_interrupt_or_end(resume_payload={"topic": topic}) is None:
        return  # graph ended early (quiz interrupt never happened)

    # 3) Resume quiz_grader_agent with the answer as a plain string.
    quiz_answer = ask_non_empty("Please provide an answer to the quiz below:\n>> ")
    await run_until_interrupt_or_end(resume_payload=quiz_answer)


In [None]:
# Re-load credentials and re-fetch the Langfuse client right before the interactive run.
# load_dotenv, get_client and observe are already imported in the first cell, so they
# are not re-imported here.
load_dotenv('config.env')
langfuse = get_client()

In [None]:
await hitl_interaction_flow(
    graph=graph,
    thread_id=1
)