"""
Agent 协调器
根据用户的自然语言输入，智能路由到对应的 Agent
使用 OpenAI Agents 框架的 handoff 模式
"""
from agents import Agent, Runner, function_tool, RunConfig, ModelSettings, SQLiteSession
from agents.models.openai_chatcompletions import OpenAIChatCompletionsModel
from openai import OpenAI, AsyncOpenAI
from typing import Dict, Any, Optional
import json
import os
import asyncio

# 导入各个 Agent
import sys
from pathlib import Path
sys.path.append(str(Path(__file__).parent.parent))

# 加载环境变量
from dotenv import load_dotenv
load_dotenv()
# Enable auto tracing for OpenAI Agents SDK
import mlflow
mlflow.openai.autolog()

# Optional: Set a tracking URI and an experiment
mlflow.set_tracking_uri("http://127.0.0.1:5000")
mlflow.set_experiment("OpenAI Agent")

# 提示词文件路径
PROMPT_FILE = Path(__file__).parent.parent / "prompts" / "agent_coordinator.md"

# 导入工具函数
from tools.schedule_tools import (
    create_schedule as _create_schedule,
    delete_schedule as _delete_schedule,
    update_schedule as _update_schedule,
    query_schedules as _query_schedules,
    set_current_session_id as set_schedule_session_id
)
from tools.wechat_tools import (
    send_wechat_message as _send_wechat_message,
    query_message_history as _query_wechat_messages,
    set_current_session_id as set_wechat_session_id
)
from config.settings import LLM_API_KEY, LLM_BASE_URL, LLM_MODEL


class AgentCoordinator:
    """Agent 协调器 - 使用 OpenAI Agents 框架的 handoff 模式"""
    
    # 全局 Session 缓存（支持多轮对话）
    # key: session_id, value: SQLiteSession 对象
    _session_cache: Dict[str, SQLiteSession] = {}
    
    @staticmethod
    def _load_system_prompt() -> str:
        """从文件加载系统提示词"""
        try:
            if PROMPT_FILE.exists():
                with open(PROMPT_FILE, 'r', encoding='utf-8') as f:
                    return f.read()
            else:
                print(f"⚠️ 提示词文件不存在: {PROMPT_FILE}")
                return "你是一个智能助手协调器，负责将用户请求路由到合适的专业Agent。"
        except Exception as e:
            print(f"❌ 加载提示词失败: {e}")
            return "你是一个智能助手协调器，负责将用户请求路由到合适的专业Agent。"
    
    def __init__(self, api_key: str = None, model: str = None, base_url: str = None):
        """
        初始化 Agent 协调器
        
        Args:
            api_key: API Key，如果不提供则从配置文件读取
            model: 使用的模型名称
            base_url: API 基础 URL，如果不提供则从配置文件读取
        """
        # 使用传入的参数，如果没有则使用配置文件中的值
        api_key = api_key or LLM_API_KEY
        base_url = base_url or LLM_BASE_URL or "https://dashscope.aliyuncs.com/compatible-mode/v1"
        model = model or LLM_MODEL or "qwen-max"
        
        # 构建客户端参数（使用同步客户端）
        client_kwargs = {
            "api_key": api_key,
            "base_url": base_url
        }
        
        self.client = OpenAI(**client_kwargs)
        self.model = model
        
        print(f"\n{'='*60}")
        print(f"🤖 Agent 协调器初始化")
        print(f"{'='*60}")
        print(f"📍 API 地址: {base_url}")
        print(f"🧠 模型名称: {model}")
        print(f"{'='*60}\n")
        
        # 使用 function_tool 装饰器包装日程管理工具
        @function_tool
        def create_schedule(title: str, start_time: str, end_time: str, 
                          description: str = "", location: str = "", 
                          participants: list = None) -> str:
            """创建一个新的日程安排
            
            Args:
                title: 日程标题
                start_time: 开始时间，格式: YYYY-MM-DD HH:MM
                end_time: 结束时间，格式: YYYY-MM-DD HH:MM
                description: 日程描述
                location: 地点
                participants: 参与人列表
            """
            print(f"\n🔧 [工具调用] create_schedule")
            print(f"   参数: title={title}, start_time={start_time}, end_time={end_time}")
            result = _create_schedule(title, start_time, end_time, description, location, participants or [])
            print(f"   结果: {result}")
            return result
        
        @function_tool
        def delete_schedule(schedule_id: int) -> str:
            """删除指定的日程
            
            Args:
                schedule_id: 日程ID
            """
            return _delete_schedule(schedule_id)
        
        @function_tool
        def update_schedule(schedule_id: int, title: str = None, start_time: str = None, 
                          end_time: str = None, description: str = None, 
                          location: str = None, participants: list = None) -> str:
            """更新指定的日程
            
            Args:
                schedule_id: 日程ID
                title: 新的标题
                start_time: 新的开始时间
                end_time: 新的结束时间
                description: 新的描述
                location: 新的地点
                participants: 新的参与人列表
            """
            return _update_schedule(schedule_id, title, start_time, end_time, description, location, participants)
        
        @function_tool
        def query_schedules(start_date: str = None, end_date: str = None, keyword: str = None) -> str:
            """查询日程列表
            
            Args:
                start_date: 开始日期，格式: YYYY-MM-DD
                end_date: 结束日期，格式: YYYY-MM-DD
                keyword: 关键词（搜索标题和描述）
            """
            return _query_schedules(start_date, end_date, keyword)
        
        # 使用 function_tool 装饰器包装微信工具
        @function_tool
        def send_wechat_message(recipient: str, message: str, scheduled_time: str = None) -> str:
            """发送消息
            
            Args:
                recipient: 收信人
                message: 消息内容
                scheduled_time: 定时发送时间，格式: YYYY-MM-DD HH:MM（可选）
            """
            return _send_wechat_message(recipient, message, scheduled_time)
        
        @function_tool
        def query_wechat_messages(recipient: str = None, start_date: str = None, end_date: str = None) -> str:
            """查询消息历史
            
            Args:
                recipient: 收信人（可选）
                start_date: 开始日期，格式: YYYY-MM-DD（可选）
                end_date: 结束日期，格式: YYYY-MM-DD（可选）
            """
            return _query_wechat_messages(recipient, start_date, end_date)
        
        # 加载日程管理提示词
        schedule_prompt_file = Path(__file__).parent.parent / "prompts" / "schedule_agent.md"
        try:
            with open(schedule_prompt_file, 'r', encoding='utf-8') as f:
                schedule_instructions = f.read()
        except:
            schedule_instructions = "你是一个专业的日程管理助手，负责帮助用户创建、查询、更新和删除日程安排。"
        
        # 加载微信助手提示词
        wechat_prompt_file = Path(__file__).parent.parent / "prompts" / "wechat_agent.md"
        try:
            with open(wechat_prompt_file, 'r', encoding='utf-8') as f:
                wechat_instructions = f.read()
        except:
            wechat_instructions = "你是一个专业的微信消息助手，负责帮助用户发送微信消息、查询消息历史。"
        
        # 创建 AsyncOpenAI 客户端（兼容阿里云等第三方 API）
        async_client = AsyncOpenAI(
            base_url=base_url,
            api_key=api_key
        )
        
        # 创建 OpenAIChatCompletionsModel（使用标准的 chat/completions 端点）
        chat_model = OpenAIChatCompletionsModel(
            model=self.model,
            openai_client=async_client
        )
        
        # 配置模型参数
        model_settings = ModelSettings()
        model_settings.temperature = 0.7
        
        # 准备额外的 API 参数（处理 DashScope 等特定平台的要求）
        if "dashscope" in base_url.lower():
            # DashScope 要求非流式调用时 enable_thinking 必须为 false
            # 使用 extra_body 而不是 extra_args，避免 OpenAI SDK 参数校验错误
            model_settings.extra_body = {"enable_thinking": False}
            print("🔧 检测到 DashScope，添加 enable_thinking=False 参数")
        
        # 创建日程管理 Agent（使用英文名称避免中文字符警告）
        self.schedule_agent = Agent(
            name="schedule_agent",
            instructions=schedule_instructions,
            tools=[create_schedule, delete_schedule, update_schedule, query_schedules],
            model=chat_model,
            model_settings=model_settings,
            handoff_description= "用于日程管理的助手"
        )
        
        # 创建微信消息 Agent
        self.wechat_agent = Agent(
            name="wechat_agent",
            instructions=wechat_instructions,
            tools=[send_wechat_message, query_wechat_messages],
            model=chat_model,
            model_settings=model_settings,
            handoff_description= "用于处理微信消息的助手"
        )
        
        # 加载路由 Agent 的系统提示词
        system_prompt = self._load_system_prompt()
        
        # 创建路由 Agent（Triage Agent）- 使用 handoffs 模式
        self.triage_agent = Agent(
            name="triage_agent",
            instructions=system_prompt,
            handoffs=[self.schedule_agent, self.wechat_agent],
            model=chat_model,
            model_settings=model_settings
        )
        
        # 创建 RunConfig（禁用 tracing）
        self.run_config = RunConfig(
            tracing_disabled=True,
        )
        
        # 创建 Runner
        self.runner = Runner()

    
    def process_stream(self, user_input: str, current_time: str = None, session_id: str = None):
        """
        流式处理用户输入，每执行完一个工具就立即返回结果
        
        Args:
            user_input: 用户的自然语言输入
            current_time: 当前时间字符串
            session_id: 会话 ID，用于多轮对话
            
        Yields:
            包含处理步骤信息的字典，类型包括：
            - {"type": "routing", "agent": agent_name} - 路由信息
            - {"type": "tool_call", "name": tool_name, "arguments": args} - 工具调用开始
            - {"type": "tool_result", "name": tool_name, "arguments": args, "result": result} - 工具执行完成
            - {"type": "final", "result": final_response} - 最终响应
            - {"type": "error", "message": error_message} - 错误信息
        """
        try:
            print(f"\n🔄 [流式] 开始处理请求: {user_input}")
            print(f"🤖 使用模型: {self.model}")
            
            # 创建或获取 Session
            session = None
            if session_id:
                # 从缓存获取或创建新的 Session（支持多轮对话）
                if session_id not in self._session_cache:
                    self._session_cache[session_id] = SQLiteSession(
                        session_id=session_id,
                        db_path=":memory:"  # 内存数据库
                    )
                    print(f"📊 [流式] 创建新的内存 Session: {session_id}")
                else:
                    print(f"📊 [流式] 复用已有 Session: {session_id}")
                
                session = self._session_cache[session_id]
                
                # 设置当前会话ID到工具上下文（用于日程和消息隔离）
                set_schedule_session_id(session_id)
                set_wechat_session_id(session_id)
            else:
                # 使用默认会话
                set_schedule_session_id("default")
                set_wechat_session_id("default")
            
            # 添加当前时间上下文
            actual_input = user_input
            if current_time:
                actual_input = f"[当前时间: {current_time}]\n\n{user_input}"
                print(f"⏰ 添加时间上下文: {current_time}")
            
            # 定义异步生成器来处理流式结果
            async def run_agent_stream(original_input, max_retries=3):
                """运行 Agent 并流式返回工具执行信息（支持重试）"""
                from agents.exceptions import ModelBehaviorError
                
                # 用于跟踪工具调用和结果
                tool_calls = {}  # key: tool_id, value: {name, arguments}
                routed_agent = None
                seen_agents = set()  # 记录已经返回过的 agent，避免重复
                last_error_msg = ""  # 保存上次错误信息
                
                # 重试循环
                for attempt in range(max_retries):
                    # 每次重试使用的输入（只在重试时添加错误反馈）
                    current_input = original_input
                    current_session = session  # 默认使用正常 Session
                    
                    # 如果不是第一次尝试，添加错误反馈，并使用临时内存 Session
                    if attempt > 0 and last_error_msg:
                        # 提取上次错误的工具名称
                        import re
                        tool_match = re.search(r'Tool (\w+) not found', last_error_msg)
                        tool_name = tool_match.group(1) if tool_match else "unknown"
                        
                        # 构建临时错误反馈（只用于本次重试）
                        feedback = f"""
🚨 系统错误提示 🚨

你刚才尝试直接调用工具 "{tool_name}"，但这是错误的！

❌ 你的角色：triage_agent（路由助手）
❌ 你没有任何工具权限

✅ 你唯一能做的事：使用 handoff 转交

📝 正确示例：
用户："查询我这周的日程"
你的响应：{{使用 handoff 转交到 schedule_agent}}

⚠️ 现在请重新处理下面的用户请求，这次务必使用 handoff，不要调用任何工具！
"""
                        print(f"🔄 [流式] 第 {attempt + 1} 次尝试：添加错误反馈并重试（使用临时 Session）...")
                        # 临时添加反馈，只用于本次 run 调用
                        current_input = f"{feedback}\n\n【用户请求】：{user_input}"
                        # 重试时使用临时内存 Session（不污染主 Session，但能记录工具调用）
                        current_session = SQLiteSession(
                            session_id=f"temp_retry_{attempt}",
                            db_path=":memory:"  # 使用内存数据库，不持久化
                        )
                    
                    # 使用真正的流式 API：run_streamed（注意：这是同步方法，不需要 await）
                    result_streaming = self.runner.run_streamed(
                        starting_agent=self.triage_agent,
                        input=current_input,
                        run_config=self.run_config,
                        session=current_session,
                        max_turns=25
                    )
                    print(f"✅ [流式] 第 {attempt + 1} 次尝试启动成功，开始实时流式处理...")
                    
                    # 执行成功，使用真正的流式 API 处理事件
                    try:
                        print(f"\n📊 [流式] 开始真正的实时流式处理...")
                        
                        # 使用 async for 实时处理流式事件
                        async for event in result_streaming.stream_events():
                            event_type = event.type
                            
                            # 1. Agent 更新事件 - 表示路由到新的 Agent
                            if event_type == "agent_updated_stream_event":
                                new_agent_name = event.new_agent.name
                                if new_agent_name != "triage_agent" and new_agent_name not in seen_agents:
                                    seen_agents.add(new_agent_name)
                                    routed_agent = new_agent_name
                                    print(f"  🔀 [实时流式] Agent 路由: {new_agent_name}")
                                    yield {
                                        "type": "routing",
                                        "agent": new_agent_name
                                    }
                            
                            # 2. RunItem 事件 - 包含工具调用、工具结果、消息等
                            elif event_type == "run_item_stream_event":
                                event_name = event.name
                                item = event.item
                                
                                # 工具调用事件
                                if event_name == "tool_called":
                                    if hasattr(item, 'raw_item'):
                                        raw = item.raw_item
                                        tool_id = getattr(raw, 'id', None) or getattr(raw, 'call_id', None)
                                        tool_name = getattr(raw, 'name', None)
                                        tool_args = getattr(raw, 'arguments', None)
                                        
                                        if tool_name:
                                            # 过滤掉 handoff 函数（transfer_to_* 开头的都是 Agent 切换，不是真正的工具）
                                            if tool_name.startswith('transfer_to_'):
                                                print(f"  🔀 [实时流式] 跳过 handoff 函数: {tool_name}")
                                                continue
                                            
                                            # 解析参数
                                            if isinstance(tool_args, str):
                                                try:
                                                    tool_args = json.loads(tool_args)
                                                except:
                                                    pass
                                            
                                            # 保存工具调用信息
                                            if tool_id:
                                                tool_calls[tool_id] = {
                                                    "name": tool_name,
                                                    "arguments": tool_args
                                                }
                                            
                                            print(f"  🔧 [实时流式] 工具调用: {tool_name}")
                                            yield {
                                                "type": "tool_call",
                                                "name": tool_name,
                                                "arguments": tool_args
                                            }
                                
                                # 工具结果事件
                                elif event_name == "tool_output":
                                    if hasattr(item, 'output'):
                                        output = item.output
                                        
                                        # 尝试找到对应的工具调用 ID
                                        result_tool_id = None
                                        if hasattr(item, 'raw_item'):
                                            raw = item.raw_item
                                            if isinstance(raw, dict):
                                                result_tool_id = raw.get('tool_call_id')
                                            elif hasattr(raw, 'tool_call_id'):
                                                result_tool_id = raw.tool_call_id
                                        
                                        # 查找对应的工具调用
                                        tool_info = tool_calls.get(result_tool_id)
                                        if not tool_info and tool_calls:
                                            # 如果找不到，使用最近的一个
                                            tool_info = list(tool_calls.values())[-1]
                                        
                                        if tool_info:
                                            # 过滤掉 handoff 函数的结果
                                            if tool_info['name'].startswith('transfer_to_'):
                                                print(f"  🔀 [实时流式] 跳过 handoff 结果: {tool_info['name']}")
                                                continue
                                            
                                            print(f"  ✅ [实时流式] 工具结果: {tool_info['name']}")
                                            yield {
                                                "type": "tool_result",
                                                "name": tool_info["name"],
                                                "arguments": tool_info["arguments"],
                                                "result": output
                                            }
                                
                                # 消息输出事件 - 可能是最终响应的一部分
                                elif event_name == "message_output_created":
                                    # 暂不处理，等待流式完成后再返回最终响应
                                    pass
                        
                        # 流式完成，返回最终响应
                        final_response = ""
                        if hasattr(result_streaming, 'final_output'):
                            final_response = str(result_streaming.final_output)
                        
                        print(f"  💬 [实时流式] 最终响应")
                        yield {
                            "type": "final",
                            "result": final_response or "处理完成",
                            "routed_agent": routed_agent
                        }
                        
                        # 成功完成，跳出重试循环
                        break
                        
                    except Exception as e:
                        print(f"  ❌ [流式] 错误: {str(e)}")
                        yield {
                            "type": "error",
                            "message": f"处理失败: {str(e)}"
                        }
                        break
            
            # 创建事件循环并真正流式处理
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            try:
                # 创建异步生成器
                async_gen = run_agent_stream(actual_input)
                
                # 使用 while True 循环实时获取每个事件
                while True:
                    try:
                        # 获取下一个事件
                        event = loop.run_until_complete(async_gen.__anext__())
                        # 立即 yield，实现真正的流式返回
                        yield event
                    except StopAsyncIteration:
                        # 异步生成器结束
                        break
                    
            finally:
                loop.close()
            
            print(f"✅ [流式] 处理完成")
            
        except Exception as e:
            error_msg = f"流式处理失败: {str(e)}"
            print(f"❌ {error_msg}")
            import traceback
            traceback.print_exc()
            yield {
                "type": "error",
                "message": error_msg
            }
    
    def run(self):
        """运行交互式协调器（使用流式处理）"""
        from datetime import datetime
        
        print("=" * 60)
        print("🤖 智能助手协调器已启动（流式模式）")
        print("=" * 60)
        print("我可以帮你：")
        print("  📅 管理日程：创建、查询、修改、删除日程安排")
        print("  💬 发送微信：给好友发送微信消息，支持定时发送")
        print("\n试试这些：")
        print("  - '帮我创建明天下午2点的项目会议'")
        print("  - '发消息给张三说明天见'")
        print("  - '查询我本周的所有日程'")
        print("  - '明天上午10点提醒李四开会'")
        print("\n输入 'quit' 退出")
        print("=" * 60)
        
        session_id = "cli_session"  # 使用固定的会话ID支持多轮对话
        
        # 固定用户输入
        user_input = "发消息给张三明天见"
        print(f"\n👤 您: {user_input}")
        
        try:
            current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
            print(f"\n{'=' * 60}")
            print(f"🔄 开始处理...")
            
            final_result = ""
            routed_agent = None
            
            # 流式处理用户输入
            for event in self.process_stream(user_input, current_time=current_time, session_id=session_id):
                event_type = event.get("type")
                
                if event_type == "routing":
                    agent_name = event.get("agent", "unknown")
                    agent_display = {
                        "schedule_agent": "📅 日程管理助手",
                        "wechat_agent": "💬 微信消息助手"
                    }.get(agent_name, agent_name)
                    print(f"🔀 路由到: {agent_display}")
                    routed_agent = agent_name
                    
                elif event_type == "tool_call":
                    tool_name = event.get("name", "unknown")
                    tool_args = event.get("arguments", {})
                    print(f"🔧 调用工具: {tool_name}")
                    print(f"   参数: {tool_args}")
                    
                elif event_type == "tool_result":
                    tool_name = event.get("name", "unknown")
                    result = event.get("result", "")
                    print(f"✅ 工具执行完成: {tool_name}")
                    print(f"   结果: {result}")
                    
                elif event_type == "final":
                    final_result = event.get("result", "")
                    
                elif event_type == "error":
                    error_msg = event.get("message", "未知错误")
                    print(f"❌ 错误: {error_msg}")
            
            # 显示最终结果
            print(f"\n{'=' * 60}")
            print(f"🎯 最终结果:")
            print(f"{final_result}")
            print(f"{'=' * 60}")
            
        except Exception as e:
            print(f"\n❌ 错误: {str(e)}")
            import traceback
            traceback.print_exc()


if __name__ == "__main__":
    # 测试代码
    coordinator = AgentCoordinator()
    coordinator.run()
