In [None]:
import dotenv
dotenv.load_dotenv(override=True)

import os
import uuid
import json
from typing import List, Union, Dict
from pydantic import BaseModel, Field
from pydantic_core import to_jsonable_python
from pydantic_ai import Agent, RunContext
from pydantic_ai.models.openai import OpenAIModel
from pydantic_ai.providers.openai import OpenAIProvider
from pydantic_ai.mcp import MCPServerStdio
from pydantic_ai.messages import (
    ToolReturnPart, 
    TextPart, 
    ModelResponse, 
    ModelMessage, 
    ModelMessagesTypeAdapter, 
    UserPromptPart,
)
import asyncio
import nest_asyncio
nest_asyncio.apply()

In [6]:
# Initialize Openai-like model
model = OpenAIModel(
    'gemma3:27b-it-qat',
    provider=OpenAIProvider(
        base_url=os.getenv("AGENT_BASE_URL"),
        api_key=os.getenv("AGENT_API_KEY")
    ),
)

In [None]:
# MCP server
rag_mcp_command = "uv run python /agent_framework_demo/demo/1_rag_server.py stdio" #TODO: 要使用絕對路徑

rag_mcp_server = MCPServerStdio(
    command = rag_mcp_command.split(" ")[0],
    args = rag_mcp_command.split(" ")[1:],
    env = {
        "OPENAI_BASE_URL": os.getenv("OPENAI_BASE_URL"),
        "OPENAI_API_KEY": os.getenv("OPENAI_API_KEY")
    },
    timeout=60,  # 增加 timeout
)

In [51]:
# structured output
from pydantic import BaseModel, Field

class Reference(BaseModel):
    score: float
    #relevant: bool
    text: str

class ReferenceWithCriticism(BaseModel):
    relevant: bool
    text: str
    
class Criticism(BaseModel):
    passages: List[ReferenceWithCriticism]
    explanation: str
    sufficient: bool

class RAGFormat(BaseModel):
    reference: List[str] = Field(
        default_factory=list,
        description="List of passages text used to generate the answer."
    )
    final_answer: str


In [52]:
# Agent initialization
assistant_agent = Agent(
    name="assistant_agent",
    model = model,
    instructions = (
        "You are a helpful and intelligent assistant. The user you are helping speaks Traditional Chinese and comes from Taiwan, so in most cases, you should respond in Traditional Chinese. \n"
        "Behavior Rules: \n"
        "1. Direct Answering: If the question is clear and within your knowledge, answer directly.\n"
        "2. Clarification: If the question is vague or unclear, ask clarifying questions to understand the user’s intent before responding.\n"
        "3. Retrieval: If the question requires specialized or external knowledge, use the `rag` to obtain an answer based on relevant information retrieved from the database. Return the result with `RAGFormat` format.\n"
    ),
    output_type=Union[RAGFormat, str],

)
seeker_agent = Agent(
    name="seeker_agent",
    model=model,
    instructions=(
        "You are a helpful and intelligent assistant. The user you are helping speaks Traditional Chinese and comes from Taiwan, so in most cases, you should respond in Traditional Chinese. \n"
        "Behavior Rules: \n"
        "1. Retrieve relevant documents from the specific vector database. When retrieving from the database, the user's original intent should be preserved as much as possible, and the clarity of the question's meaning should be maintained.\n"
        "2. Return all the retrieval results with Reference format.\n"
    ),
    mcp_servers=[rag_mcp_server],
    output_type=List[Reference],
)

critic_agent = Agent(
    name="critic_agent",
    model=model, 
    instructions=(
        "You are a helpful and intelligent assistant. The user you are helping speaks Traditional Chinese and comes from Taiwan, so in most cases, you should respond in Traditional Chinese. \n"
        "Behavior Rules: \n"
        "1. Critically evaluate whether the retrieved passages are relevant to the user's query.\n"
        "2. Critically evaluate whether the retrieved passages are sufficient to answer the user's query.\n"
    ),
    output_type=Criticism,
)
generator_agent = Agent(
    name="generator_agent",
    model=model,
    instructions=(
        "You are a helpful and intelligent assistant. The user you are helping speaks Traditional Chinese and comes from Taiwan, so in most cases, you should respond in Traditional Chinese. \n"
        "Behavior Rules: \n"
        "1. Generate a final answer based on the retrieved passages.\n"
        "2. If the retrieved passages are insufficient, ask follow-up questions to gather more information.\n"
    ),
    output_type=RAGFormat,
)

In [53]:
# Memory Management

class MemoryManager(BaseModel):
    user_id: str = Field(
        default= f"user_{uuid.uuid4().hex[:16]}", 
        description = "User identification"
    )
    session_id: str = Field(
        default= f"{uuid.uuid4().hex[:16]}", 
        description = "Session identification"
    )
    max_turns: int = Field(
        default= 30, 
        description = "Max turns of conversation"
    )
    save_path: str = Field(
        default=os.getcwd(),
        description="Chat conversation history saving directory",
    )
    main_agent: str = Field(
        default="assistant_agent",
        description="Name of main agent",
    )
    message_history: Union[Dict[str, list[ModelMessage]], None] = Field(
        default={},
        description="Chat conversation history",
    )

    def convert_tool_role_into_assistant(self, massage_history):
        if isinstance(massage_history[-1].parts[0], ToolReturnPart):
            raw_assistant_message = massage_history[-2]
            assistant_message = ModelResponse(
                parts=[
                    TextPart(
                        content = raw_assistant_message.parts[0].args,
                        part_kind = "text",
                    ),
                ],
                usage = raw_assistant_message.usage,
                model_name = raw_assistant_message.model_name,
                timestamp = raw_assistant_message.timestamp,
                kind = raw_assistant_message.kind,
                vendor_id = raw_assistant_message.vendor_id
                )
            return massage_history + [assistant_message]
        else:
            return massage_history
    
    def trim_history_by_max_turns(self, message_history):
        turn_start_index = {}
        num_turns = 0
        for idx, message in enumerate(message_history):
            try:
                if isinstance(message.parts[0], UserPromptPart):
                    turn_start_index[num_turns] = idx
                    num_turns += 1
            except:
                pass
        if self.max_turns > num_turns:
            return message_history
        else:
            return message_history[turn_start_index[num_turns-self.max_turns]:]
    
    def append_message(self, agent_name: str, new_messages: List[ModelMessage]) -> None:
        """
            Add new message to the conversation history of specific agent
        """
        new_messages_converted = self.convert_tool_role_into_assistant(new_messages)
        if agent_name in self.message_history.keys():
            self.message_history[agent_name].extend(new_messages_converted)
        else:
            self.message_history[agent_name] = new_messages_converted
    
    def get_history(self, agent_name: str) -> Union[list[ModelMessage], None]:
        if agent_name in self.message_history.keys():
            return self.trim_history_by_max_turns(self.message_history[agent_name])
        else:
            print(f"No history of {agent_name}...")
            return None
    
    def save_history(self):
        history_as_python_objects = to_jsonable_python(self.get_history("assistant_agent"))

        saving_dir = os.path.join(self.save_path, self.user_id)
        os.makedirs(saving_dir, exist_ok=True)
        with open(os.path.join(saving_dir, f"{self.session_id}.json"), "w") as f:
            json.dump(history_as_python_objects, f, indent=4, ensure_ascii=False)
    
    def clear_message(self) -> None:
        """
            Reset chat history
        """
        self.message_history.clear()
    
    def load_history_from_disk(self):
        chat_history_dir = os.path.join(self.save_path, self.user_id, f"{self.session_id}.json")
        with open(chat_history_dir, "r") as f:
            history_as_python_objects = json.load(f)
        self.message_history[self.main_agent] = ModelMessagesTypeAdapter.validate_python(history_as_python_objects)


In [54]:
# multi-turn conversation
async def multi_turn_conversation(start_agent: Agent, history_manager: MemoryManager):
    while True:
        user_input = input("You: ")
        print(f"You: {user_input}")

        if user_input == "exit":
            print("Agent: Goodbye!")
            return response

        response = await start_agent.run(
            user_prompt = user_input,
            deps = history_manager,
            message_history = history_manager.get_history(start_agent.name),
        )
        print(f"Agent: {response.output}")
        history_manager.append_message(
            agent_name=start_agent.name,
            new_messages=response.new_messages(),
        )

In [55]:
async def retrieve_info(history_manager: MemoryManager, query: str):
    """
        Retrieve relevant information based on the user query.
    """
    n_retries = 3
    async with seeker_agent.run_mcp_servers():
        for _ in range(n_retries):
            print("start seeker agent...")
            response = await seeker_agent.run(
                user_prompt = query,
                deps = history_manager,
            )
            history_manager.append_message(
                agent_name=seeker_agent.name,
                new_messages=response.new_messages(),
            )
            try:
                if isinstance(response.output[0], Reference):
                    print("seeker agent:", response.output)
                    #print(response.usage())
                    #print(response.new_messages())
                    return response.output
            except Exception as e:
                print(f"Error in seeker agent response: {e}")

async def get_passage_criticism(history_manager: MemoryManager, user_query: str, retrieved_passages: List[Reference]):
    """
        Get criticism for each retrieved passage.
    """
    input_text = (
        "The user's query and the retrieved passages are as follows:\n"
        f"## User Query: {user_query}\n"
        "## Retrieved Passages:\n"
        + "\n".join([f"[{i + 1}] {passage.text}" for i, passage in enumerate(retrieved_passages)])
    )
    n_retries = 3
    for _ in range(n_retries):
        response = await critic_agent.run(
            user_prompt=input_text,
            deps=history_manager,
        )
        history_manager.append_message(
            agent_name=critic_agent.name,
            new_messages=response.new_messages(),
        )
        if isinstance(response.output, Criticism):
            print("critic agent:", response.output)
            return response.output


async def generate_answer(history_manager: MemoryManager, user_query: str, passages_with_criticism: Criticism):
    """
        Generate final answer to the user query based on retrieved passages.
    """
    input_text = (
        "The user's query and the retrieved passages are as follows:\n"
        f"## User Query: {user_query}\n"
        "## Relevant Passages:\n"
        + "\n".join([f"[{i+1}] {passage.text}" for i, passage in enumerate(passages_with_criticism.passages) if passage.relevant])
    )
    n_retries = 3
    for _ in range(n_retries):
        response = await generator_agent.run(
            user_prompt=input_text,
            deps=history_manager,
        )
        history_manager.append_message(
            agent_name=generator_agent.name,
            new_messages=response.new_messages(),
        )
        if isinstance(response.output, RAGFormat):
            print("generator agent:", response.output)
            return response.output
    

@assistant_agent.tool()
async def rag(ctx: RunContext[MemoryManager], query: str) -> Union[RAGFormat ,Criticism, None]:
    """
        response the user query based on external knowledge
    """
    retrieved_passages = await retrieve_info(ctx.deps, query)
    critic_result = await get_passage_criticism(ctx.deps, ctx.prompt, retrieved_passages)
    if critic_result.sufficient:
        final_answer = await generate_answer(ctx.deps, ctx.prompt, critic_result)
        return final_answer
    else:
        print("Insufficient criticism from the critic agent.")
        return critic_result

In [56]:
async def main(history_manager: MemoryManager):
    response = await multi_turn_conversation(assistant_agent, history_manager)
    return response

In [None]:
history_manager = MemoryManager(
    user_id="user_1", 
    session_id="session_1",
    save_path=os.path.join(os.getcwd(), "pydanticai_storage")
)

result = await main(history_manager)

You: 我想要請假，有沒有什麼限制？
No history of assistant_agent...
start seeker agent...
seeker agent: [Reference(score=0.7, text='此為依據勞動基準法所訂定的最基準工作時間。\n一、每週工作時數不得超過四十小時。\n二、每日工作時數為八小時。\n三、每日工作時間為\n08:30 開始上班\n12:00 午休及午餐時間\n13:00 下午工作時間開始\n15:00 下午休息時間\n15:15 繼續工作\n17:30 下班\n子女未滿一歲須員工親自哺乳者，除規定之休息時間外，本公司將每日另給哺乳時間二次，每次以三十分鐘為度，哺乳時間，視為工作時間。\n員工為撫育未滿三歲子女，得請求下列所定事項之一：\n1.\t每天減少工作時間一小時；減少之工作時間，不得請求報酬。\n2.\t調整工作時間。\n員工為前二項哺乳時間、減少或調整工時之請求時，本公司不得拒絕或視為缺勤而影響其全勤獎金、考績或為其他不利之處分。'), Reference(score=0.42953295, text='一、薪資採月薪制，破月者薪資以月薪除以30算出日薪資再乘實際工作天數(含工作期間的例假日)。時薪則以日薪除以8計算得之。\n二、薪資項目主要區分為二類\n基本薪資：依同仁所任職務之不同所給予之工作報酬稱之。\n主管加給：付與任管理職的同仁的薪資加給稱之。\n薪資與職務對照表另行訂定，並依當時之物價指數不同而調整。'), Reference(score=0.3, text='一、同仁因執行公務中受傷之病假，稱為公傷假。\n二、公傷假之醫療費用，以勞工保險及公司加保之團保負擔之，其餘相關規定及補償辦法比照第八章災害傷病補償及撫卹之規定辦理。\n三、職業災害未認定前，勞工得先請普通傷病假；普通傷病假期滿，得申請留職停薪，如認定結果為職業災害，再以公傷病假處理。'), Reference(score=0.28747067, text='一、紀念日，勞動節日及中央主管機關規定應放假之日，是放假日。 \n二、相關日期依當年政府規定於每年開始時，公告同仁周知。\n三、主管或同仁個人因工作上的需要必須應用此放假日工作者，必須依照延長工作時間申請程序申請延長工作。\n四、申請延長工作時間作業必須在規定期限內完成。\

In [None]:
history_manager.save_history()

In [45]:
history_manager.message_history

{'seeker_agent': [ModelRequest(parts=[UserPromptPart(content='請假規定', timestamp=datetime.datetime(2025, 7, 18, 9, 1, 35, 174590, tzinfo=datetime.timezone.utc))], instructions="You are a helpful and intelligent assistant. The user you are helping speaks Traditional Chinese and comes from Taiwan, so in most cases, you should respond in Traditional Chinese. \nBehavior Rules: \n1. Retrieve relevant documents from the specific vector database. When retrieving from the database, the user's original intent should be preserved as much as possible, and the clarity of the question's meaning should be maintained.\n2. Return all the retrieval results with Reference format."),
  ModelResponse(parts=[ToolCallPart(tool_name='retrieve_fps_rules_db', args='{"query":"請假規定"}', tool_call_id='ujgwlhUNx6INgh81hozB2bPlgRwy5AHc')], usage=Usage(requests=1, request_tokens=982, response_tokens=48, total_tokens=1030, details={}), model_name='gemma3:27b-it-qat', timestamp=datetime.datetime(2025, 7, 18, 9, 1, 37, tz