# hooks.py (V17 - Simplified WebSearch via web_search_options for Anthropic, No Query Extraction, Reduced Log Verbosity, UnboundLocalError Fix)

import litellm
import traceback
import copy # For deep copying messages before token counting
# import re # No longer needed for query extraction - [Robustness/Compatibility] marked as a legacy feature that is no longer required
from litellm.integrations.custom_logger import CustomLogger

# --- Configuration ---
GEMINI_CACHE_MINIMUMS = {
    "gemini-1.5-flash": 1024,
    "gemini-1.5-pro": 1024,
    "gemini/gemini-2.5-flash-latest": 2048,
    "gemini-1.5-flash-latest": 2048,
    "gemini/gemini-2.5-flash-lite": 2048, # Explicitly add for known mapping
    # Note: For other models, DEFAULT_MIN_CACHE_TOKENS will be used if not found here.
}
DEFAULT_MIN_CACHE_TOKENS = 2048 # Raised to 2048 based on the 400 error previously returned by the API

# --- Proposed Global Variables for Configurability ---
# [Improvement 3] The Claude model mapping is extracted into a module-level constant
CLAUDE_TO_GEMINI_MODEL_FOR_TOKEN_COUNTING = {
    "claude-3-sonnet": "gemini/gemini-2.5-flash-lite",
    "claude-3-haiku": "gemini/gemini-2.5-flash-lite",
    "claude-3-opus": "gemini/gemini-2.5-flash-lite", # Example for Opus, adjust as needed
    # Generic fallback/default mapping used when no more specific model alias matches
    "claude": "gemini/gemini-2.5-flash-lite", 
}

# [Improvement 4] Module-level constants for web_search tool detection, for greater flexibility
WEB_SEARCH_TOOL_NAME_IDENTIFIER = "web_search"
# Allowed web_search tool type prefixes, e.g. 'web_search_20250305' is matched by the prefix 'web_search_'
WEB_SEARCH_TOOL_TYPE_PREFIXES = ["web_search_"] 
# Exact-match web_search tool types, kept for compatibility with specific older versions
WEB_SEARCH_TOOL_TYPES_EXACT = ["web_search_20250305"] 
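
# Illustrative shapes of the declarations the detection logic below targets (assumptions
# based on Anthropic's documented web search tool and Gemini's function-declaration
# wrapper, not taken from an actual request):
#   {"type": "web_search_20250305", "name": "web_search", "max_uses": 5}
#   {"function_declarations": [{"name": "web_search", "parameters": {...}}]}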


class GeminiCacheManager(CustomLogger):
    def __init__(self):
        super().__init__()
        print("\n--- [GeminiCacheManager] Initialized with super robust token counting (V17) ---")

    async def async_pre_call_hook(self, user_api_key_dict, cache, data, call_type):
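        """Pre-call hook with three jobs:
        (1) act only on completion-style calls whose model alias looks like Gemini or Claude;
        (2) translate incoming 'web_search' tool declarations into LiteLLM's
            'web_search_options' request parameter; and
        (3) estimate the prompt token count and strip caching / cache_control parameters
            when it falls below the Gemini cache minimum, to avoid 400 errors.
        """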
        print(f"\n--- [HOOK V17 ACTIVATED] Pre-request callback running with call_type: '{call_type}'. ---")

        try:
            completion_call_types = {"completion", "acompletion", "anthropic_messages"}
            if call_type not in completion_call_types:
                print(f"--- [HOOK SKIPPED] Call type '{call_type}' not a completion type. Passing through.\n")
                return data # [Robustness/Filtering] only handle completion-style call types

            request_model_alias = data.get("model", "")
            
            is_gemini_model = False
            model_id_for_token_counting = request_model_alias # Default to alias

            lower_request_model_alias = request_model_alias.lower()

            # [Improvement 3] Use the module-level mapping for Claude models
            if "claude" in lower_request_model_alias:
                is_gemini_model = True
                print(f"--- [HOOK HEURISTIC] 'claude' alias detected. Attempting to map to a Gemini model for token counting.")
                model_id_for_token_counting = CLAUDE_TO_GEMINI_MODEL_FOR_TOKEN_COUNTING.get(
                    lower_request_model_alias,
                    CLAUDE_TO_GEMINI_MODEL_FOR_TOKEN_COUNTING.get("claude", "gemini/gemini-2.5-flash-lite") # If no exact match, try the generic 'claude' mapping, otherwise fall back to the default
                )
                print(f"--- [HOOK HEURISTIC] Mapped Claude alias '{request_model_alias}' to token counting model: '{model_id_for_token_counting}'.")
            elif "gemini" in lower_request_model_alias:
                is_gemini_model = True
                print(f"--- [HOOK HEURISTIC] 'gemini' alias detected. Assuming it's a Gemini model.")
            else:
                print(f"--- [HOOK SKIPPED] Model alias '{request_model_alias}' does not indicate a Gemini model. Passing through.\n")
                return data # [Robustness/Filtering] only handle Gemini or Claude models

            print(f"--- [HOOK ENGAGED] Processing Gemini model '{model_id_for_token_counting}' (request alias: '{request_model_alias}') for caching and web_search_options translation. ---")

            # --- START: WebSearch Tool to web_search_options Translation Logic ---
            activate_web_search_options = False
            
            # Step 1: Check for web_search tool declarations in the incoming 'tools' list
            if 'tools' in data and isinstance(data['tools'], list):
                modified_tools = []
                # First pass to detect 'web_search' and remove its declaration
                for tool_item in data['tools']:
                    # [Robustness] Pass non-dict tool entries through untouched
                    if not isinstance(tool_item, dict):
                        modified_tools.append(tool_item)
                        continue
                    is_web_search_tool_declaration = False
                    
                    # [Improvement 4] Detect the web_search tool via the module-level constants and more flexible matching
                    if (tool_item.get('name') == WEB_SEARCH_TOOL_NAME_IDENTIFIER and
                        (tool_item.get('type') in WEB_SEARCH_TOOL_TYPES_EXACT or
                         any(tool_item.get('type', '').startswith(prefix) for prefix in WEB_SEARCH_TOOL_TYPE_PREFIXES))):
                        is_web_search_tool_declaration = True
                        activate_web_search_options = True
                    # Check for Anthropic-style tool declaration (still focuses on name, input_schema)
                    elif tool_item.get('name') == WEB_SEARCH_TOOL_NAME_IDENTIFIER and 'input_schema' in tool_item:
                        is_web_search_tool_declaration = True
                        activate_web_search_options = True
                    # Check for Gemini-style tool wrapper containing 'web_search'
                    elif 'function_declarations' in tool_item and isinstance(tool_item['function_declarations'], list):
                        new_function_declarations = []
                        contains_web_search_in_wrapper = False
                        for func_decl in tool_item['function_declarations']:
                            if func_decl.get('name') == WEB_SEARCH_TOOL_NAME_IDENTIFIER: # use the module-level constant
                                contains_web_search_in_wrapper = True
                                activate_web_search_options = True
                                continue # This specific func_decl will be removed
                            new_function_declarations.append(func_decl)
                        
                        if contains_web_search_in_wrapper:
                            if new_function_declarations:
                                tool_item['function_declarations'] = new_function_declarations
                                modified_tools.append(tool_item)
                            continue # Don't add an empty wrapper - [Robustness] avoid appending an empty tool wrapper
                        else: # If web_search was not in this wrapper, add the original tool_item
                            modified_tools.append(tool_item)
                        continue # Done processing this wrapper
                    
                    if not is_web_search_tool_declaration: # If it's not a web_search declaration, keep it - [Robustness] keep only non-web_search tools
                        modified_tools.append(tool_item)
                    else:
                        print(f"--- [HOOK WEB_SEARCH] Removed '{WEB_SEARCH_TOOL_NAME_IDENTIFIER}' tool declaration: {tool_item.get('name', 'N/A')}")

                data['tools'] = modified_tools
                
                # Condensed tool summary logging (after the initial removal of web_search declarations) - [Robustness/Debugging] helps trace how the tools were processed
                tools_summary_after_removal = []
                for tool_item in data['tools']:
                    if isinstance(tool_item, dict):
                        if 'name' in tool_item and 'input_schema' in tool_item and isinstance(tool_item['input_schema'], dict) and 'type' in tool_item['input_schema']:
                            tools_summary_after_removal.append({
                                'name': tool_item['name'],
                                'input_schema_type': tool_item['input_schema']['type']
                            })
                        elif 'type' in tool_item and 'name' in tool_item:
                            tools_summary_after_removal.append({
                                'name': tool_item['name'],
                                'type': tool_item['type']
                            })
                        elif 'function_declarations' in tool_item and isinstance(tool_item['function_declarations'], list):
                            inner_funcs = []
                            for fd in tool_item['function_declarations']:
                                if 'name' in fd:
                                    inner_funcs.append({'name': fd['name'], 'parameters_type': fd.get('parameters', {}).get('type', 'N/A')})
                            if inner_funcs:
                                tools_summary_after_removal.append({'function_declarations_wrapper': inner_funcs})
                        else:
                            tools_summary_after_removal.append({'tool_item_unparsable': tool_item})
                    else:
                        tools_summary_after_removal.append({'non_dict_tool_item': tool_item})
                print(f"--- [HOOK TOOL] Tools (summary) after removing web_search declarations: {tools_summary_after_removal}")

            # Step 2: If web_search_options should be activated, inject it
            if activate_web_search_options:
                print(f"--- [HOOK WEB_SEARCH] '{WEB_SEARCH_TOOL_NAME_IDENTIFIER}' tool declaration was found. Activating web_search_options.")
                # According to docs, only search_context_size is needed, no explicit 'query' for Anthropic's web_search_options
                if 'web_search_options' not in data:
                    data['web_search_options'] = {} # [Robustness] make sure the dict exists to prevent a KeyError
                # Optionally set search_context_size (LiteLLM/OpenAI-style values are 'low', 'medium', or 'high')
                # data['web_search_options']['search_context_size'] = 'high'
                print(f"--- [HOOK WEB_SEARCH] Final web_search_options injected: {data['web_search_options']}")
            else:
                print("--- [HOOK WEB_SEARCH] No 'web_search' tool declaration found. Not injecting web_search_options.")
            
            # --- END: WebSearch Tool to web_search_options Translation Logic ---
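            # Net effect of the block above (shapes illustrative): a request arriving with a
            # tools entry like {"type": "web_search_20250305", "name": "web_search"} leaves
            # here with that declaration removed and data['web_search_options'] = {} injected.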

            # The rest of the hook (token counting, caching disablement, error handling) remains the same
            messages = data.get("messages")
            if not messages:
                print("--- [HOOK SKIPPED] No 'messages' found in request data. Passing through.\n")
                return data # [Robustness/Filtering] skip if there are no messages
            
            token_count = 0 
            token_counting_successful = False

            try:
                temp_messages_for_token_count = []
                for message in messages:
                    temp_message = copy.deepcopy(message) # [Robustness] deep-copy the message to avoid side effects on the original

                    # The block below strips `cache_control` from `content` and `system` and
                    # normalizes None/empty values so that token_counter stays accurate and
                    # API errors are avoided. This is robustness handling for the variety of
                    # possible input formats.
                    if "content" in temp_message:
                        content_value = temp_message["content"]
                        if content_value is None: # [Robustness] handle None values
                            temp_message["content"] = []
                        elif isinstance(content_value, list):
                            new_content_list = []
                            for item in content_value:
                                if isinstance(item, dict) and "cache_control" in item:
                                    new_item = {k: v for k, v in item.items() if k != "cache_control"}
                                    if new_item: # [Robustness] avoid appending empty dicts
                                        new_content_list.append(new_item)
                                else:
                                    new_content_list.append(item)
                            temp_message["content"] = new_content_list
                        elif isinstance(content_value, dict):
                            new_content_dict = {k: v for k, v in content_value.items() if k != "cache_control"}
                            if "text" in new_content_dict or new_content_dict: # [鲁棒性] 确保不是空的 content
                                temp_message["content"] = new_content_dict
                            else:
                                temp_message["content"] = []
                    else:
                        temp_message["content"] = []

                    if "system" in temp_message:
                        system_value = temp_message["system"]
                        if system_value is None: # [Robustness] handle None values
                            temp_message["system"] = []
                        elif isinstance(system_value, list):
                            new_system_list = []
                            for system_item in system_value:
                                if isinstance(system_item, dict) and "cache_control" in system_item:
                                    new_system_item = {k: v for k, v in system_item.items() if k != "cache_control"}
                                    if new_system_item: # [Robustness] avoid appending empty dicts
                                        new_system_list.append(new_system_item)
                                else:
                                    new_system_list.append(system_item)
                            temp_message["system"] = new_system_list
                        elif isinstance(system_value, dict):
                            new_system_dict = {k: v for k, v in system_value.items() if k != "cache_control"}
                            if "text" in new_system_dict or new_system_dict: # [鲁棒性] 确保不是空的 system
                                temp_message["system"] = new_system_dict
                            else:
                                temp_message["system"] = []
                    else:
                        temp_message["system"] = []
                    
                    if "cache_control" in temp_message:
                        del temp_message["cache_control"] # [鲁棒性] 确保顶层 cache_control 被移除

                    temp_messages_for_token_count.append(temp_message)
                
                token_count = litellm.token_counter(model=model_id_for_token_counting, messages=temp_messages_for_token_count)
                token_counting_successful = True
                
            except Exception as token_error: # [Robustness/Error handling] catch token-counting errors and force caching off
                print(f"--- [HOOK WARNING] litellm.token_counter failed for model '{model_id_for_token_counting}': {token_error}. This is likely due to internal LiteLLM issues. Forcing token_count = 0 to disable caching.")
                traceback.print_exc()
                token_count = 0 
                token_counting_successful = False 

            min_required_tokens = GEMINI_CACHE_MINIMUMS.get(model_id_for_token_counting, DEFAULT_MIN_CACHE_TOKENS)
            
            print(f"--- [HOOK ANALYSIS] Tokens (LiteLLM estimate): {token_count}. Required: {min_required_tokens}. (Token counting successful: {token_counting_successful})")

            if token_count < min_required_tokens or not token_counting_successful:
                print("--- [HOOK ACTION] Token count too low (LiteLLM estimate) or token counting failed. Proactively DISABLING 'caching' related parameters for this Gemini request to prevent 400 error.\n")
                
                # When caching is disabled, the block below strips every caching-related
                # parameter from the request; it was added after the 400 errors were
                # observed, to improve robustness.
                if "litellm_params" in data and "caching" in data.get("litellm_params", {}):
                    del data["litellm_params"]["caching"]
                if "caching" in data:
                    del data["caching"]
                
                # Walk the messages again to make sure the cleanup is thorough
                for msg_data in data.get("messages", []):
                    if "content" in msg_data:
                        content_value = msg_data["content"]
                        if isinstance(content_value, list):
                            msg_data["content"] = [{k: v for k, v in item.items() if k != "cache_control"}
                                                   if isinstance(item, dict) and "cache_control" in item else item
                                                   for item in content_value]
                        elif isinstance(content_value, dict) and "cache_control" in content_value:
                            msg_data["content"] = {k: v for k, v in content_value.items() if k != "cache_control"}
                    
                    if "system" in msg_data:
                        system_value = msg_data["system"]
                        if isinstance(system_value, list):
                            msg_data["system"] = [{k: v for k, v in item.items() if k != "cache_control"}
                                                  if isinstance(item, dict) and "cache_control" in item else item
                                                  for item in system_value]
                        elif isinstance(system_value, dict) and "cache_control" in system_value:
                            msg_data["system"] = {k: v for k, v in system_value.items() if k != "cache_control"}
                    
                    if "cache_control" in msg_data:
                        del msg_data["cache_control"]
                
                system_value_data_to_process = data.get("system") # [鲁棒性] 处理顶层 system 字段
                if system_value_data_to_process is not None:
                    if isinstance(system_value_data_to_process, list):
                        data["system"] = [{k: v for k, v in item.items() if k != "cache_control"}
                                          if isinstance(item, dict) and "cache_control" in item else item
                                          for item in system_value_data_to_process]
                    elif isinstance(system_value_data_to_process, dict) and "cache_control" in system_value_data_to_process:
                        data["system"] = {k: v for k, v in system_value_data_to_process.items() if k != "cache_control"}

            else:
                print("--- [HOOK ACTION] Token count sufficient. Allowing client-provided 'caching' (if any) and 'cache_control' to pass through.\n")
                
        except Exception as e: # [Robustness/Global error handling] catch any unexpected error and run cleanup
            print(f"--- [HOOK CRITICAL ERROR] An unexpected error occurred: {e} ---")
            traceback.print_exc()
            # Safety-net cleanup executed on any unexpected error, to avoid failing the request where possible.
            if "litellm_params" in data and "caching" in data.get("litellm_params", {}):
                del data["litellm_params"]["caching"]
            if "caching" in data:
                del data["caching"]
            
            for msg_data in data.get("messages", []):
                if "content" in msg_data:
                    content_value = msg_data["content"]
                    if isinstance(content_value, list):
                        msg_data["content"] = [{k: v for k, v in item.items() if k != "cache_control"}
                                               if isinstance(item, dict) and "cache_control" in item else item
                                               for item in content_value]
                    elif isinstance(content_value, dict) and "cache_control" in content_value:
                        # [BUGFIX for Robustness] The original code mistakenly modified the top-level data["content"]; msg_data["content"] is the correct target
                        msg_data["content"] = {k: v for k, v in content_value.items() if k != "cache_control"} 
                
                if "system" in msg_data:
                    system_value = msg_data["system"]
                    if isinstance(system_value, list):
                        msg_data["system"] = [{k: v for k, v in item.items() if k != "cache_control"}
                                              if isinstance(item, dict) and "cache_control" in item else item
                                              for item in system_value]
                    elif isinstance(system_value, dict) and "cache_control" in system_value:
                        # [BUGFIX for Robustness] The original code mistakenly modified the top-level data["system"]; msg_data["system"] is the correct target
                        msg_data["system"] = {k: v for k, v in system_value.items() if k != "cache_control"}
                
                if "cache_control" in msg_data:
                    del msg_data["cache_control"]
            
            system_value_data_to_process = data.get("system")
            if system_value_data_to_process is not None:
                if isinstance(system_value_data_to_process, list):
                    data["system"] = [{k: v for k, v in item.items() if k != "cache_control"}
                                      if isinstance(item, dict) and "cache_control" in item else item
                                      for item in system_value_data_to_process]
                elif isinstance(system_value_data_to_process, dict) and "cache_control" in system_value_data_to_process:
                    data["system"] = {k: v for k, v in system_value_data_to_process.items() if k != "cache_control"}
        
        return data

gemini_cache_hook = GeminiCacheManager()
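
# --- Usage sketch (assumption; exact registration depends on your LiteLLM proxy setup) ---
# LiteLLM's proxy can load a CustomLogger instance through the `callbacks` entry in
# config.yaml. Assuming this file is importable as `hooks`, registration would look like:
#
#   litellm_settings:
#     callbacks: hooks.gemini_cache_hook
#
# With that in place, async_pre_call_hook runs on each proxied completion request before
# it is forwarded to the provider.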